From 03e85ade9f50bf8d77a87561ba6a0c8ed53bf6aa Mon Sep 17 00:00:00 2001 From: Patrick Guo Date: Mon, 25 Nov 2024 16:19:45 -0500 Subject: [PATCH] restructuring to use data driven approach --- pyrad/__init__.py | 2 +- pyrad/datatypes/__init__.py | 0 pyrad/datatypes/base.py | 61 ++++ pyrad/datatypes/leaf.py | 566 ++++++++++++++++++++++++++++++++++ pyrad/datatypes/structural.py | 131 ++++++++ pyrad/dictionary.py | 73 ++++- pyrad/packet.py | 105 +++++-- pyrad/parser.py | 88 ++++++ pyrad/tools.py | 255 --------------- pyrad/utility.py | 34 ++ tests/testDatatypes.py | 98 ++++++ tests/testDictionary.py | 36 +-- tests/testPacket.py | 2 +- tests/testTools.py | 127 -------- 14 files changed, 1126 insertions(+), 452 deletions(-) create mode 100644 pyrad/datatypes/__init__.py create mode 100644 pyrad/datatypes/base.py create mode 100644 pyrad/datatypes/leaf.py create mode 100644 pyrad/datatypes/structural.py create mode 100644 pyrad/parser.py delete mode 100644 pyrad/tools.py create mode 100644 pyrad/utility.py create mode 100644 tests/testDatatypes.py delete mode 100644 tests/testTools.py diff --git a/pyrad/__init__.py b/pyrad/__init__.py index 56f924e..c1ec1d3 100644 --- a/pyrad/__init__.py +++ b/pyrad/__init__.py @@ -43,4 +43,4 @@ __copyright__ = 'Copyright 2002-2023 Wichert Akkerman, Istvan Ruzman and Christian Giese. All rights reserved.' __version__ = '2.4' -__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile'] +__all__ = ['client', 'dictionary', 'packet', 'server', 'datatypes', 'dictfile'] diff --git a/pyrad/datatypes/__init__.py b/pyrad/datatypes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyrad/datatypes/base.py b/pyrad/datatypes/base.py new file mode 100644 index 0000000..c8d3d76 --- /dev/null +++ b/pyrad/datatypes/base.py @@ -0,0 +1,61 @@ +""" +base.py + +Contains base datatype +""" +from abc import ABC, abstractmethod + +class AbstractDatatype(ABC): + """ + Root of entire datatype class hierarchy + """ + def __init__(self, name): + self.name = name + + @abstractmethod + def encode(self, attribute, decoded, *args, **kwargs): + """ + turns python data structure into bytes + + :param *args: + :param **kwargs: + :param attribute: + :param decoded: python data structure to encode + :return: encoded bytes + """ + + @abstractmethod + def print(self, attribute, decoded, *args, **kwargs): + """ + returns string representation of decoding + + :param *args: + :param **kwargs: + :param attribute: attribute object + :param decoded: value pair + :return: string + """ + + @abstractmethod + def parse(self, dictionary, string, *args, **kwargs): + """ + returns python structure from ASCII string + + :param *args: + :param **kwargs: + :param dictionary: + :param string: ASCII string of attribute + :return: python structure for attribute + """ + + @abstractmethod + def get_value(self, attribute, packet, offset, *args, **kwargs): + """ + retrieves the encapsulated value + :param *args: + :param **kwargs: + :param attribute: attribute value + :param packet: packet + :param offset: attribute starting position + :return: encapsulated value, and bytes read + """ diff --git a/pyrad/datatypes/leaf.py b/pyrad/datatypes/leaf.py new file mode 100644 index 0000000..54540b2 --- /dev/null +++ b/pyrad/datatypes/leaf.py @@ -0,0 +1,566 @@ +""" +leaf.py + +Contains all leaf datatypes (ones that can be encoded and decoded directly) +""" +import binascii +import enum +import struct +from abc import ABC, abstractmethod +from datetime import datetime +from ipaddress import IPv4Address, IPv6Network, IPv6Address, IPv4Network, \ + AddressValueError + +from netaddr import EUI, core + +from pyrad.datatypes import base + + +class AbinaryKeystores(enum.Enum): + PYRAD = 1 + FREERADIUS = 2 + +class AbstractLeaf(base.AbstractDatatype, ABC): + """ + abstract class for leaf datatypes + """ + @abstractmethod + def decode(self, raw, *args, **kwargs): + """ + turns bytes into python data structure + + :param *args: + :param **kwargs: + :param raw: bytes + :return: python data structure + """ + + def get_value(self, attribute, packet, offset, *args, **kwargs): + _, attr_len = struct.unpack('!BB', packet[offset:offset + 2])[0:2] + return packet[offset + 2:offset + attr_len], attr_len + +class AscendBinary(AbstractLeaf): + """ + leaf datatype class for ascend binary + """ + def __init__(self): + super().__init__('abinary') + + def encode(self, attribute, decoded, *args, **kwargs): + # unless specified throw kwargs, use the pyrad keystore to avoid + # causing breakages + keystore = AbinaryKeystores.PYRAD + if 'keystore_abinary' in kwargs: + keystore = kwargs['keystore_abinary'] + + match keystore: + case AbinaryKeystores.PYRAD: + terms = { + 'family': b'\x01', + 'action': b'\x00', + 'direction': b'\x01', + 'src': b'\x00\x00\x00\x00', + 'dst': b'\x00\x00\x00\x00', + 'srcl': b'\x00', + 'dstl': b'\x00', + 'proto': b'\x00', + 'sport': b'\x00\x00', + 'dport': b'\x00\x00', + 'sportq': b'\x00', + 'dportq': b'\x00' + } + case AbinaryKeystores.FREERADIUS: + terms = { + 'srcip': b'=x00', + 'dstip': None, + 'srcmask': '', + 'dstmask': '', + 'proto': '', + 'established': '', + 'srcport': '', + 'dstport': '', + 'srcPortCmp': '', + 'dstPortCmp': '', + 'fill': '' + } + case _: + raise ValueError('Unexpected abinary keystore name') + + + family = 'ipv4' + for t in decoded.split(' '): + key, value = t.split('=') + if key == 'family' and value == 'ipv6': + family = 'ipv6' + terms[key] = b'\x03' + if terms['src'] == b'\x00\x00\x00\x00': + terms['src'] = 16 * b'\x00' + if terms['dst'] == b'\x00\x00\x00\x00': + terms['dst'] = 16 * b'\x00' + elif key == 'action' and value == 'accept': + terms[key] = b'\x01' + elif key == 'action' and value == 'redirect': + terms[key] = b'\x20' + elif key == 'direction' and value == 'out': + terms[key] = b'\x00' + elif key in ('src', 'dst'): + if family == 'ipv4': + ip = IPv4Network(value) + else: + ip = IPv6Network(value) + terms[key] = ip.network_address.packed + terms[key + 'l'] = struct.pack('B', ip.prefixlen) + elif key in ('sport', 'dport'): + terms[key] = struct.pack('!H', int(value)) + elif key in ('sportq', 'dportq', 'proto'): + terms[key] = struct.pack('B', int(value)) + + trailer = 8 * b'\x00' + + result = b''.join( + (terms['family'], terms['action'], terms['direction'], b'\x00', + terms['src'], terms['dst'], terms['srcl'], terms['dstl'], + terms['proto'], b'\x00', + terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], + b'\x00\x00', trailer)) + return result + + def decode(self, raw, *args, **kwargs): + # just return the raw binary string + return raw + + def print(self, attribute, decoded): + # the binary string is what we are looking for + return decoded + + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string + +class Byte(AbstractLeaf): + """ + leaf datatype class for bytes (1 byte unsigned int) + """ + def __init__(self): + super().__init__('byte') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as byte') from exc + return struct.pack('!B', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!B', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + # cast int to string before returning + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as byte') from e + else: + if num < 0: + raise ValueError('Parsed value too small for byte') + if num > 255: + raise ValueError('Parsed value too large for byte') + return num + +class Date(AbstractLeaf): + """ + leaf datatype class for dates + """ + def __init__(self): + super().__init__('date') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, int): + raise TypeError('Can not encode non-integer as date') + return struct.pack('!I', decoded) + + def decode(self, raw, *args, **kwargs): + # dates are stored as ints + return (struct.unpack('!I', raw))[0] + + def print(self, attribute, decoded, *args, **kwargs): + # turn seconds since epoch into timestamp with given format + return datetime.fromtimestamp(decoded).strftime('%Y-%m-%dT%H:%M:%S') + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + # parse string using given string, and return seconds since epoch + # as an int + return int(datetime.strptime(string, '%Y-%m-%dT%H:%M:%S') + .timestamp()) + except ValueError as e: + raise TypeError('Failed to parse date') from e + +class Ether(AbstractLeaf, ABC): + """ + leaf datatype class for ethernet addresses + """ + def __init__(self): + super().__init__('ether') + + def encode(self, attribute, decoded, *args, **kwargs): + return struct.pack('!6B', *map(lambda x: int(x, 16), decoded.split(':'))) + + def decode(self, raw, *args, **kwargs): + # return EUI object containing mac address + return EUI(':'.join(map('{0:02x}'.format, struct.unpack('!6B', raw)))) + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError('Can not encode non-string as ethernet address') + + try: + return EUI(string) + except core.AddrFormatError as e: + raise ValueError('Could not decode ethernet address') from e + +class Ifid(AbstractLeaf, ABC): + """ + leaf datatype class for IFID (IPV6 interface ID) + """ + def __init__(self): + super().__init__('ifid') + + def encode(self, attribute, decoded, *args, **kwargs): + struct.pack('!HHHH', *map(lambda x: int(x, 16), decoded.split(':'))) + + def decode(self, raw, *args, **kwargs): + ':'.join(map('{0:04x}'.format, struct.unpack('!HHHH', raw))) + + def print(self, attribute, decoded, *args, **kwargs): + # Following freeradius, IFIDs are displayed as a hex without any + # delimiters + return decoded.replace(':', '') + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + # adds a : delimiter after every second character + return ':'.join((string[i:i + 2] for i in range(0, len(string), 2))) + +class Integer(AbstractLeaf): + """ + leaf datatype class for integers + """ + def __init__(self): + super().__init__('integer') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as integer') from exc + return struct.pack('!I', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!I', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as int') from e + else: + if num < 0: + raise ValueError('Parsed value too small for int') + if num > 4294967295: + raise ValueError('Parsed value too large for int') + return num + +class Integer64(AbstractLeaf): + """ + leaf datatype class for 64bit integers + """ + def __init__(self): + super().__init__('integer64') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as 64bit integer') from exc + return struct.pack('!Q', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!Q', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as int64') from e + else: + if num < 0: + raise ValueError('Parsed value too small for int64') + if num > 18446744073709551615: + raise ValueError('Parsed value too large for int64') + return num + +class Ipaddr(AbstractLeaf): + """ + leaf datatype class for ipv4 addresses + """ + def __init__(self): + super().__init__('ipaddr') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('Address has to be a string') + return IPv4Address(decoded).packed + + def decode(self, raw, *args, **kwargs): + # stored as strings, not ipaddress objects + return '.'.join(map(str, struct.unpack('BBBB', raw))) + + def print(self, attribute, decoded, *args, **kwargs): + # since object is already stored as a string, just return it as is + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + # check if string is valid ipv4 address, but still returning the + # string representation + return IPv4Address(string).exploded + except AddressValueError as e: + raise TypeError('Parsing invalid IPv4 address') from e + +class Ipv6addr(AbstractLeaf): + """ + leaf datatype class for ipv6 addresses + """ + def __init__(self): + super().__init__('ipv6addr') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('IPv6 Address has to be a string') + return IPv6Address(decoded).packed + + def decode(self, raw, *args, **kwargs): + addr = raw + b'\x00' * (16 - len(raw)) + prefix = ':'.join( + map(lambda x: f'{0:x}', struct.unpack('!' + 'H' * 8, addr)) + ) + return str(IPv6Address(prefix)) + + def print(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError(f'Parsing expects a string, got {type(decoded)}') + + try: + # check if valid address, but return string representation + return IPv6Address(decoded).exploded + except AddressValueError as e: + raise TypeError('Parsing invalid IPv6 address') from e + + def parse(self, dictionary, string, *args, **kwargs): + return string + +class Ipv6prefix(AbstractLeaf): + """ + leaf datatype class for ipv6 prefixes + """ + def __init__(self): + super().__init__('ipv6prefix') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('IPv6 Prefix has to be a string') + ip = IPv6Network(decoded) + return (struct.pack('2B', *[0, ip.prefixlen]) + + ip.network_address.packed) + + def decode(self, raw, *args, **kwargs): + addr = raw + b'\x00' * (18 - len(raw)) + _, length, prefix = ':'.join( + map(lambda x: f'{0:x}' , struct.unpack('!BB' + 'H' * 8, addr)) + ).split(":", 2) + # returns string representation in the form of / + return str(IPv6Network(f'{prefix}/{int(length, 16)}')) + + def print(self, attribute, decoded, *args, **kwargs): + # we already store this value as a string, so just return it as is + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + return str(IPv6Network(string)) + except AddressValueError as e: + raise TypeError('Parsing invalid IPv6 prefix') from e + +class Octets(AbstractLeaf): + """ + leaf datatype class for octets + """ + def __init__(self): + super().__init__('octets') + + def encode(self, attribute, decoded, *args, **kwargs): + # Check for max length of the hex encoded with 0x prefix, as a sanity check + if len(decoded) > 508: + raise ValueError('Can only encode strings of <= 253 characters') + + if isinstance(decoded, bytes) and decoded.startswith(b'0x'): + hexstring = decoded.split(b'0x')[1] + encoded_octets = binascii.unhexlify(hexstring) + elif isinstance(decoded, str) and decoded.startswith('0x'): + hexstring = decoded.split('0x')[1] + encoded_octets = binascii.unhexlify(hexstring) + elif isinstance(decoded, str) and decoded.isdecimal(): + encoded_octets = struct.pack('>L', int(decoded)).lstrip( + b'\x00') + else: + encoded_octets = decoded + + # Check for the encoded value being longer than 253 chars + if len(encoded_octets) > 253: + raise ValueError('Can only encode strings of <= 253 characters') + + return encoded_octets + + def decode(self, raw, *args, **kwargs): + return raw + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string + +class Short(AbstractLeaf): + """ + leaf datatype class for short integers + """ + def __init__(self): + super().__init__('short') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as integer') from exc + return struct.pack('!H', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!H', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as short') from e + else: + if num < 0: + raise ValueError('Parsed value too small for short') + if num > 65535: + raise ValueError('Parsed value too large for short') + return num + +class Signed(AbstractLeaf): + """ + leaf datatype class for signed integers + """ + def __init__(self): + super().__init__('signed') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as signed integer') from exc + return struct.pack('!i', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!i', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as signed') from e + else: + if num < -2147483648: + raise ValueError('Parsed value too small for signed') + if num > 2147483647: + raise ValueError('Parsed value too large for signed') + return num + +class String(AbstractLeaf): + """ + leaf datatype class for strings + """ + def __init__(self): + super().__init__('string') + + def encode(self, attribute, decoded, *args, **kwargs): + if len(decoded) > 253: + raise ValueError('Can only encode strings of <= 253 characters') + if isinstance(decoded, str): + return decoded.encode('utf-8') + return decoded + + def decode(self, raw, *args, **kwargs): + return raw.decode('utf-8') + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string diff --git a/pyrad/datatypes/structural.py b/pyrad/datatypes/structural.py new file mode 100644 index 0000000..536b4c0 --- /dev/null +++ b/pyrad/datatypes/structural.py @@ -0,0 +1,131 @@ +""" +structural.py + +Contains all structural datatypes +""" +import struct + +from abc import ABC +from pyrad.datatypes import base +from pyrad.parser import ParserTLV +from pyrad.utility import tlv_name_to_codes, vsa_name_to_codes + +parser_tlv = ParserTLV() + +class AbstractStructural(base.AbstractDatatype, ABC): + """ + abstract class for structural datatypes + """ + +class Tlv(AbstractStructural): + """ + structural datatype class for TLV + """ + def __init__(self): + super().__init__('tlv') + + def encode(self, attribute, decoded, *args, **kwargs): + encoding = b'' + for key, value in decoded.items(): + encoding += attribute.sub_attributes[key].encode(value, ) + + if len(encoding) + 2 > 255: + raise ValueError('TLV length too long for one packet') + + return (struct.pack('!B', attribute.code) + + struct.pack('!B', len(encoding) + 2) + + encoding) + + def get_value(self, attribute: 'Attribute', packet, offset, *args, + **kwargs): + sub_attrs = {} + + _, outer_len = struct.unpack('!BB', packet[offset:offset + 2])[0:2] + + if outer_len < 3: + raise ValueError('TLV length too short') + if offset + outer_len > len(packet): + raise ValueError('TLV length too long') + + # move cursor to TLV value + cursor = offset + 2 + while cursor < offset + outer_len: + sub_type, sub_len = struct.unpack( + '!BB', packet[cursor:cursor + 2] + )[0:2] + + if sub_len < 3: + raise ValueError('TLV length field too small') + + value, subattr_offset = attribute.sub_attributes[sub_type].type.get_value( + attribute, packet, cursor) + sub_attrs.setdefault(sub_type, []).append(value) + cursor += subattr_offset + return sub_attrs, outer_len + + def print(self, attribute, decoded, *args, **kwargs): + sub_attr_strings = [sub_attr.print() + for sub_attr in attribute.sub_attributes] + return f"{attribute.name} = {{ {', '.join(sub_attr_strings)} }}" + + def parse(self, dictionary, string, *args, **kwargs): + return tlv_name_to_codes(dictionary, parser_tlv.parse(string)) + +class Vsa(AbstractStructural): + """ + structural datatype class for VSA + """ + def __init__(self): + super().__init__('vsa') + + def encode(self, attribute, decoded, *args, **kwargs): + encoding = b'' + + for key, value in decoded.items(): + encoding += attribute.sub_attributes[key].encode(value, ) + + return (struct.pack('!B', attribute.code) + + struct.pack('!B', len(encoding) + 4) + + struct.pack('!L', attribute.vendor) + + encoding) + + def get_value(self, attribute, packet, offset, *args, **kwargs): + sub_attrs = {} + + _, outer_len = struct.unpack( + '!BB', packet[offset:offset + 2] + )[0:2] + + if outer_len < 8: + # in malformed packets, take everything after the outlet len as + # the vendor name and set the tlv to be empty + return {packet[offset + 2:offset + outer_len]: {}}, outer_len + if offset + outer_len > len(packet): + raise ValueError('VSA length too long') + + vendor_id = struct.unpack('!L', packet[offset + 2:offset + 6])[0] + + cursor = offset + 6 + while cursor < offset + outer_len: + sub_type, sub_len = struct.unpack( + '!BB', packet[cursor:cursor + 2] + )[0:2] + + if sub_len < 3: + raise ValueError('TLV length field too small') + + sub_attr = attribute.sub_attributes[vendor_id][sub_type] + + value, offset = sub_attr.type.get_value(sub_attr, packet, cursor) + sub_attrs.setdefault(sub_type, []).append(value) + cursor += offset + + return {vendor_id: sub_attrs}, outer_len + + def print(self, attribute, decoded, *args, **kwargs): + sub_attr_strings = [sub_attr.print() + for sub_attr in attribute.sub_attributes] + return f"Vendor-Specific = {{ {attribute.vendor} = {{ {', '.join(sub_attr_strings)} }}" + + def parse(self, dictionary, string, *args, **kwargs): + return vsa_name_to_codes(dictionary, parser_tlv.parse(string)) diff --git a/pyrad/dictionary.py b/pyrad/dictionary.py index abe5263..a0ddd6e 100644 --- a/pyrad/dictionary.py +++ b/pyrad/dictionary.py @@ -71,19 +71,36 @@ | | where 'h' is hex digits, upper or lowercase. | +---------------+----------------------------------------------+ """ + +from copy import copy + from pyrad import bidict -from pyrad import tools from pyrad import dictfile -from copy import copy -import logging +from pyrad.datatypes import structural +from pyrad.datatypes import leaf __docformat__ = 'epytext en' - -DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets', - 'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte', - 'signed', 'ifid', 'ether', 'tlv', 'integer64']) - +from pyrad.datatypes.structural import AbstractStructural + +DATATYPES = { + 'abinary': leaf.AscendBinary(), + 'byte': leaf.Byte(), + 'date': leaf.Date(), + 'ether': leaf.Ether(), + 'ifid': leaf.Ifid(), + 'integer': leaf.Integer(), + 'integer64': leaf.Integer64(), + 'ipaddr': leaf.Ipaddr(), + 'ipv6addr': leaf.Ipv6addr(), + 'ipv6prefix': leaf.Ipv6prefix(), + 'octets': leaf.Octets(), + 'short': leaf.Short(), + 'signed': leaf.Signed(), + 'string': leaf.String(), + 'tlv': structural.Tlv(), + 'vsa': structural.Vsa() +} class ParseError(Exception): """Dictionary parser exceptions. @@ -115,13 +132,14 @@ def __str__(self): class Attribute(object): - def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', values=None, - encrypt=0, has_tag=False): + def __init__(self, name, code, datatype: str, + is_sub_attribute=False, vendor='', values=None, encrypt=0, + has_tag=False): if datatype not in DATATYPES: raise ValueError('Invalid data type') self.name = name self.code = code - self.type = datatype + self.type = DATATYPES[datatype] self.vendor = vendor self.encrypt = encrypt self.has_tag = has_tag @@ -133,6 +151,24 @@ def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', valu for (key, value) in values.items(): self.values.Add(key, value) + def decode(self, raw): + # Use datatype.decode to decode leaf attributes + if isinstance(raw, bytes): + if isinstance(self.type, AbstractStructural): + raise ValueError('Structural datatype holding string!') + return self.type.decode(raw) + + # Recursively calls sub attribute's .decode() until a leaf attribute + # is reached + for sub_attr, value in raw.items(): + raw[sub_attr] = self.sub_attributes[sub_attr].decode(value) + return raw + + def encode(self, decoded): + return self.type.encode(self, decoded) + + def get_value(self, packet, offset): + return self.type.get_value(self, packet, offset) class Dictionary(object): """RADIUS dictionary class. @@ -250,23 +286,27 @@ def keyval(o): line=state['line']) if vendor: if is_sub_attribute: - key = (self.vendors.GetForward(vendor), parent_code, code) + key = (26, self.vendors.GetForward(vendor), parent_code, code) else: - key = (self.vendors.GetForward(vendor), code) + key = (26, self.vendors.GetForward(vendor), code) else: if is_sub_attribute: key = (parent_code, code) else: key = code + attr = Attribute(attribute, code, datatype, is_sub_attribute, vendor, encrypt=encrypt, has_tag=has_tag) + self.attrindex.Add(attribute, key) - self.attributes[attribute] = Attribute(attribute, code, datatype, is_sub_attribute, vendor, encrypt=encrypt, has_tag=has_tag) + self.attributes[attribute] = attr + if vendor: + self.attributes['Vendor-Specific'].sub_attributes[self.vendors.GetForward(vendor)][code] = attr if datatype == 'tlv': # save attribute in tlvs state['tlvs'][code] = self.attributes[attribute] if is_sub_attribute: # save sub attribute in parent tlv and update their parent field - state['tlvs'][parent_code].sub_attributes[code] = attribute + state['tlvs'][parent_code].sub_attributes[code] = self.attributes[attribute] self.attributes[attribute].parent = state['tlvs'][parent_code] def __ParseValue(self, state, tokens, defer): @@ -289,7 +329,7 @@ def __ParseValue(self, state, tokens, defer): if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']: value = int(value, 0) - value = tools.EncodeAttr(adef.type, value) + value = adef.encode(value) self.attributes[attr].values.Add(key, value) def __ParseVendor(self, state, tokens): @@ -323,6 +363,7 @@ def __ParseVendor(self, state, tokens): (vendorname, vendor) = tokens[1:3] self.vendors.Add(vendorname, int(vendor, 0)) + self.attributes['Vendor-Specific'].sub_attributes[int(vendor)] = {} def __ParseBeginVendor(self, state, tokens): if len(tokens) != 2: diff --git a/pyrad/packet.py b/pyrad/packet.py index 4564f8f..750c02b 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -4,8 +4,13 @@ # # A RADIUS packet as defined in RFC 2138 -from collections import OrderedDict import struct +from collections import OrderedDict + +from pyrad.datatypes.leaf import Octets, Integer +from pyrad.datatypes.structural import Tlv, Vsa +from pyrad.dictionary import Dictionary, Attribute + try: import secrets random_generator = secrets.SystemRandom() @@ -27,7 +32,6 @@ # BBB for python 2.4 import md5 md5_constructor = md5.new -from pyrad import tools # Packet codes AccessRequest = 1 @@ -101,7 +105,7 @@ def __init__(self, code=0, id=None, secret=b'', authenticator=None, self.raw_packet = None if 'dict' in attributes: - self.dict = attributes['dict'] + self.dict: Dictionary = attributes['dict'] if 'packet' in attributes: self.raw_packet = attributes['packet'] @@ -248,14 +252,14 @@ def _DecodeValue(self, attr, value): if attr.values.HasBackward(value): return attr.values.GetBackward(value) else: - return tools.DecodeAttr(attr.type, value) + return attr.decode(value) def _EncodeValue(self, attr, value): result = '' if attr.values.HasForward(value): result = attr.values.GetForward(value) else: - result = tools.EncodeAttr(attr.type, value) + result = attr.encode(value) if attr.encrypt == 2: # salt encrypt attribute @@ -275,7 +279,7 @@ def _EncodeKeyValues(self, key, values): key = self._EncodeKey(key) if tag: tag = struct.pack('B', int(tag)) - if attr.type == "integer": + if isinstance(attr.type, Integer): return (key, [tag + self._EncodeValue(attr, v)[1:] for v in values]) else: return (key, [tag + self._EncodeValue(attr, v) for v in values]) @@ -333,10 +337,10 @@ def __getitem__(self, key): values = OrderedDict.__getitem__(self, self._EncodeKey(key)) attr = self.dict.attributes[key] - if attr.type == 'tlv': # return map from sub attribute code to its values + if isinstance(attr.type, Tlv): # return map from sub attribute code to its values res = {} for (sub_attr_key, sub_attr_val) in values.items(): - sub_attr_name = attr.sub_attributes[sub_attr_key] + sub_attr_name = attr.sub_attributes[sub_attr_key].name sub_attr = self.dict.attributes[sub_attr_name] for v in sub_attr_val: res.setdefault(sub_attr_name, []).append(self._DecodeValue(sub_attr, v)) @@ -485,7 +489,7 @@ def _PktEncodeAttributes(self): result = b'' for (code, datalst) in self.items(): attribute = self.dict.attributes.get(self._DecodeKey(code)) - if attribute and attribute.type == 'tlv': + if isinstance(attribute.type, Tlv): result += self._PktEncodeTlv(code, datalst) else: for data in datalst: @@ -501,7 +505,7 @@ def _PktDecodeVendorAttribute(self, data): (vendor, atype, length) = struct.unpack('!LBB', data[:6])[0:3] attribute = self.dict.attributes.get(self._DecodeKey((vendor, atype))) try: - if attribute and attribute.type == 'tlv': + if isinstance(attribute.type, Tlv): self._PktDecodeTlvAttribute((vendor, atype), data[6:length + 4]) tlvs = [] # tlv is added to the packet inside _PktDecodeTlvAttribute else: @@ -533,7 +537,15 @@ def DecodePacket(self, packet): received from the network and decode it. :param packet: raw packet - :type packet: string""" + :type packet: bytestring""" + + # the presence of some attributes require us to perform certain + # actions. this dict maps the attribute names to the functions to + # perform those actions + attr_actions = { + 'Message-Authenticator': self.__attr_action_message_authenticator + } + try: (self.code, self.id, length, self.authenticator) = \ @@ -548,33 +560,62 @@ def DecodePacket(self, packet): self.clear() - packet = packet[20:] - while packet: + cursor = 20 + while cursor < len(packet): try: - (key, attrlen) = struct.unpack('!BB', packet[0:2]) + (key, length) = struct.unpack('!BB', packet[cursor:cursor + 2]) except struct.error: raise PacketError('Attribute header is corrupt') - if attrlen < 2: - raise PacketError( - 'Attribute length is too small (%d)' % attrlen) - - value = packet[2:attrlen] - attribute = self.dict.attributes.get(self._DecodeKey(key)) - if key == 26: - for (key, value) in self._PktDecodeVendorAttribute(value): - self.setdefault(key, []).append(value) - elif key == 80: - # POST: Message Authenticator AVP is present. - self.message_authenticator = True - self.setdefault(key, []).append(value) - elif attribute and attribute.type == 'tlv': - self._PktDecodeTlvAttribute(key,value) + if length < 2: + raise PacketError(f'Attribute length is too small {length}') + + attribute: Attribute = self.dict.attributes.get(self._DecodeKey(key)) + + # perform attribute actions as needed + if attribute.name in attr_actions: + attr_actions[attribute.name](attribute, packet, cursor) + + if attribute is None: + raise PacketError(f'Unknown attribute key {key}') + + raw, offset = attribute.get_value(packet, cursor) + + # TODO :: move this VSA specific logic away from here + if isinstance(attribute.type, Vsa): + vsa = self.setdefault(attribute.code, {}) + vendor_id = list(raw.keys())[0] + attrs = vsa.setdefault(vendor_id, {}) + self[26] = self.__vendor_merge(attrs, raw) else: - self.setdefault(key, []).append(value) + self.setdefault(attribute.code, []).append(raw) + + cursor += offset + + def __attr_action_message_authenticator(self, attribute, packet, offset): + # if the Message-Authenticator attribute is present, set the + # class attribute to True + self.message_authenticator = True - packet = packet[attrlen:] + def __vendor_merge(self, vendor, raw): + results = {} + + all_keys = set(vendor.keys()).union(raw.keys()) + + for key in all_keys: + vendor_val = vendor.get(key) + raw_val = raw.get(key) + + if isinstance(vendor_val, dict) and isinstance(raw_val, dict): + results[key] = self.__vendor_merge(vendor_val, raw_val) + elif isinstance(vendor_val, list): + results[key] = vendor_val + [raw_val] + elif vendor_val is not None: + results[key] = [vendor_val, raw_val] + else: + results[key] = raw_val + return results def _salt_en_decrypt(self, data, salt): result = b'' @@ -796,7 +837,7 @@ def VerifyChapPasswd(self, userpwd): if isinstance(userpwd, str): userpwd = userpwd.strip().encode('utf-8') - chap_password = tools.DecodeOctets(self.get(3)[0]) + chap_password = Octets().decode(self.get(3)[0]) if len(chap_password) != 17: return False diff --git a/pyrad/parser.py b/pyrad/parser.py new file mode 100644 index 0000000..78f0d13 --- /dev/null +++ b/pyrad/parser.py @@ -0,0 +1,88 @@ +""" +BNF form of string TLVs + + ::= + ::= " = " ( | ("{ " " }")) + ::= (", " )* + ::= ([A-Z] | [a-z] | [0-9])+ +""" + +class ParseError(Exception): + pass + +class ParserTLV: + """ + Recursive descent parser for TLVs (and similar structural datatypes) + """ + def __init__(self): + self.__buffer: str = None + self.__cursor: int = None + + def parse(self, buffer): + self.__buffer = buffer + self.__cursor = 0 + + return self.__state_vp() + + def __state_vp(self): + vp = {} + + # get key for current vp + key = self.__state_string() + + # check for and move past '=' token + if not self.__buffer[self.__cursor] == '=': + raise ParseError('Did not find equal sign at position') + self.__cursor += 1 + self.__remove_whitespace() + + if self.__buffer[self.__cursor] == '{': + # move past '{' token + self.__cursor += 1 + self.__remove_whitespace() + + value = self.__state_vps() + + # check for and move past '}' token + if not self.__buffer[self.__cursor] == '}': + raise ParseError('Did not find closing bracket') + self.__cursor += 1 + self.__remove_whitespace() + else: + value = self.__state_string() + + vp[key] = value + return vp + + def __state_vps(self): + vps = {} + while True: + vps.update(self.__state_vp()) + if not self.__buffer[self.__cursor] == ',': + break + # move past ',' token + self.__cursor += 1 + self.__remove_whitespace() + self.__remove_whitespace() + return vps + + def __state_string(self): + string = self.__get_word() + self.__remove_whitespace() + return string + + def __get_word(self): + cursor_start = self.__cursor + while self.__cursor < len(self.__buffer): + if (not self.__buffer[self.__cursor].isalnum() + and self.__buffer[self.__cursor] not in ['-', '_']): + return self.__buffer[cursor_start:self.__cursor] + self.__cursor += 1 + return self.__buffer[cursor_start:self.__cursor] + + def __remove_whitespace(self): + while self.__cursor < len(self.__buffer): + if not self.__buffer[self.__cursor].isspace(): + return + self.__cursor += 1 + return diff --git a/pyrad/tools.py b/pyrad/tools.py deleted file mode 100644 index 303eb7a..0000000 --- a/pyrad/tools.py +++ /dev/null @@ -1,255 +0,0 @@ -# tools.py -# -# Utility functions -from ipaddress import IPv4Address, IPv6Address -from ipaddress import IPv4Network, IPv6Network -import struct -import binascii - - -def EncodeString(origstr): - if len(origstr) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - if isinstance(origstr, str): - return origstr.encode('utf-8') - else: - return origstr - - -def EncodeOctets(octetstring): - # Check for max length of the hex encoded with 0x prefix, as a sanity check - if len(octetstring) > 508: - raise ValueError('Can only encode strings of <= 253 characters') - - if isinstance(octetstring, bytes) and octetstring.startswith(b'0x'): - hexstring = octetstring.split(b'0x')[1] - encoded_octets = binascii.unhexlify(hexstring) - elif isinstance(octetstring, str) and octetstring.startswith('0x'): - hexstring = octetstring.split('0x')[1] - encoded_octets = binascii.unhexlify(hexstring) - elif isinstance(octetstring, str) and octetstring.isdecimal(): - encoded_octets = struct.pack('>L',int(octetstring)).lstrip((b'\x00')) - else: - encoded_octets = octetstring - - # Check for the encoded value being longer than 253 chars - if len(encoded_octets) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - - return encoded_octets - - -def EncodeAddress(addr): - if not isinstance(addr, str): - raise TypeError('Address has to be a string') - return IPv4Address(addr).packed - - -def EncodeIPv6Prefix(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Prefix has to be a string') - ip = IPv6Network(addr) - return struct.pack('2B', *[0, ip.prefixlen]) + ip.ip.packed - - -def EncodeIPv6Address(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Address has to be a string') - return IPv6Address(addr).packed - - -def EncodeAscendBinary(orig_str): - """ - Format: List of type=value pairs separated by spaces. - - Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32' - - Note: redirect(0x20) action is added for http-redirect (walled garden) use case - - Type: - family ipv4(default) or ipv6 - action discard(default) or accept or redirect - direction in(default) or out - src source prefix (default ignore) - dst destination prefix (default ignore) - proto protocol number / next-header number (default ignore) - sport source port (default ignore) - dport destination port (default ignore) - sportq source port qualifier (default 0) - dportq destination port qualifier (default 0) - - Source/Destination Port Qualifier: - 0 no compare - 1 less than - 2 equal to - 3 greater than - 4 not equal to - """ - - terms = { - 'family': b'\x01', - 'action': b'\x00', - 'direction': b'\x01', - 'src': b'\x00\x00\x00\x00', - 'dst': b'\x00\x00\x00\x00', - 'srcl': b'\x00', - 'dstl': b'\x00', - 'proto': b'\x00', - 'sport': b'\x00\x00', - 'dport': b'\x00\x00', - 'sportq': b'\x00', - 'dportq': b'\x00' - } - - family = 'ipv4' - for t in orig_str.split(' '): - key, value = t.split('=') - if key == 'family' and value == 'ipv6': - family = 'ipv6' - terms[key] = b'\x03' - if terms['src'] == b'\x00\x00\x00\x00': - terms['src'] = 16 * b'\x00' - if terms['dst'] == b'\x00\x00\x00\x00': - terms['dst'] = 16 * b'\x00' - elif key == 'action' and value == 'accept': - terms[key] = b'\x01' - elif key == 'action' and value == 'redirect': - terms[key] = b'\x20' - elif key == 'direction' and value == 'out': - terms[key] = b'\x00' - elif key == 'src' or key == 'dst': - if family == 'ipv4': - ip = IPv4Network(value) - else: - ip = IPv6Network(value) - terms[key] = ip.network_address.packed - terms[key+'l'] = struct.pack('B', ip.prefixlen) - elif key == 'sport' or key == 'dport': - terms[key] = struct.pack('!H', int(value)) - elif key == 'sportq' or key == 'dportq' or key == 'proto': - terms[key] = struct.pack('B', int(value)) - - trailer = 8 * b'\x00' - - result = b''.join((terms['family'], terms['action'], terms['direction'], b'\x00', - terms['src'], terms['dst'], terms['srcl'], terms['dstl'], terms['proto'], b'\x00', - terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], b'\x00\x00', trailer)) - return result - - -def EncodeInteger(num, format='!I'): - try: - num = int(num) - except: - raise TypeError('Can not encode non-integer as integer') - return struct.pack(format, num) - - -def EncodeInteger64(num, format='!Q'): - try: - num = int(num) - except: - raise TypeError('Can not encode non-integer as integer64') - return struct.pack(format, num) - - -def EncodeDate(num): - if not isinstance(num, int): - raise TypeError('Can not encode non-integer as date') - return struct.pack('!I', num) - - -def DecodeString(orig_str): - return orig_str.decode('utf-8') - - -def DecodeOctets(orig_bytes): - return orig_bytes - - -def DecodeAddress(addr): - return '.'.join(map(str, struct.unpack('BBBB', addr))) - - -def DecodeIPv6Prefix(addr): - addr = addr + b'\x00' * (18-len(addr)) - _, length, prefix = ':'.join(map('{0:x}'.format, struct.unpack('!BB'+'H'*8, addr))).split(":", 2) - return str(IPv6Network("%s/%s" % (prefix, int(length, 16)))) - - -def DecodeIPv6Address(addr): - addr = addr + b'\x00' * (16-len(addr)) - prefix = ':'.join(map('{0:x}'.format, struct.unpack('!'+'H'*8, addr))) - return str(IPv6Address(prefix)) - - -def DecodeAscendBinary(orig_bytes): - return orig_bytes - - -def DecodeInteger(num, format='!I'): - return (struct.unpack(format, num))[0] - -def DecodeInteger64(num, format='!Q'): - return (struct.unpack(format, num))[0] - -def DecodeDate(num): - return (struct.unpack('!I', num))[0] - - -def EncodeAttr(datatype, value): - if datatype == 'string': - return EncodeString(value) - elif datatype == 'octets': - return EncodeOctets(value) - elif datatype == 'integer': - return EncodeInteger(value) - elif datatype == 'ipaddr': - return EncodeAddress(value) - elif datatype == 'ipv6prefix': - return EncodeIPv6Prefix(value) - elif datatype == 'ipv6addr': - return EncodeIPv6Address(value) - elif datatype == 'abinary': - return EncodeAscendBinary(value) - elif datatype == 'signed': - return EncodeInteger(value, '!i') - elif datatype == 'short': - return EncodeInteger(value, '!H') - elif datatype == 'byte': - return EncodeInteger(value, '!B') - elif datatype == 'date': - return EncodeDate(value) - elif datatype == 'integer64': - return EncodeInteger64(value) - else: - raise ValueError('Unknown attribute type %s' % datatype) - - -def DecodeAttr(datatype, value): - if datatype == 'string': - return DecodeString(value) - elif datatype == 'octets': - return DecodeOctets(value) - elif datatype == 'integer': - return DecodeInteger(value) - elif datatype == 'ipaddr': - return DecodeAddress(value) - elif datatype == 'ipv6prefix': - return DecodeIPv6Prefix(value) - elif datatype == 'ipv6addr': - return DecodeIPv6Address(value) - elif datatype == 'abinary': - return DecodeAscendBinary(value) - elif datatype == 'signed': - return DecodeInteger(value, '!i') - elif datatype == 'short': - return DecodeInteger(value, '!H') - elif datatype == 'byte': - return DecodeInteger(value, '!B') - elif datatype == 'date': - return DecodeDate(value) - elif datatype == 'integer64': - return DecodeInteger64(value) - else: - raise ValueError('Unknown attribute type %s' % datatype) diff --git a/pyrad/utility.py b/pyrad/utility.py new file mode 100644 index 0000000..91e11d6 --- /dev/null +++ b/pyrad/utility.py @@ -0,0 +1,34 @@ +def tlv_name_to_codes(dictionary, tlv): + """ + recursive function to change all the keys in a TLV from strings to + codes + + :param dictionary: dictionary containing attribute name to key mappings + :param tlv: tlv with attribute names + :return: tlv with attribute keys + """ + updated = {} + for key, value in tlv.items(): + code = dictionary.attrindex[key] + + # in nested structures, pyrad stored the entire OID in a single tuple + # but we only want the last code + if isinstance(code, tuple): + code = code[-1] + + if isinstance(value, str): + updated[code] = value + else: + updated[code] = tlv_name_to_codes(dictionary, value) + return updated + + +def vsa_name_to_codes(dictionary, vsa): + updated = {'Vendor-Specific': {}} + + for vendor, tlv in vsa['Vendor-Specific'].items(): + vendor_id = dictionary.vendors[vendor] + vendor_tlv = tlv_name_to_codes(dictionary, tlv) + updated['Vendor-Specific'][vendor_id] = vendor_tlv + + return updated diff --git a/tests/testDatatypes.py b/tests/testDatatypes.py new file mode 100644 index 0000000..c8159fe --- /dev/null +++ b/tests/testDatatypes.py @@ -0,0 +1,98 @@ +from ipaddress import AddressValueError +from pyrad.datatypes.leaf import * +import unittest + + +class LeafEncodingTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.abinary = AscendBinary() + cls.byte = Byte() + cls.date = Date() + cls.ether = Ether() + cls.ifid = Ifid() + cls.integer = Integer() + cls.integer64 = Integer64() + cls.ipaddr = Ipaddr() + cls.ipv6addr = Ipv6addr() + cls.ipv6prefix = Ipv6prefix() + cls.octets = Octets() + cls.short = Short() + cls.signed = Signed() + cls.string = String() + + def testStringEncoding(self): + self.assertRaises(ValueError, self.string.encode, None, 'x' * 254) + self.assertEqual( + self.string.encode(None, '1234567890'), + b'1234567890') + + def testInvalidStringEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.string.encode, None, 1) + + def testAddressEncoding(self): + self.assertRaises(AddressValueError, self.ipaddr.encode, None,'TEST123') + self.assertEqual( + self.ipaddr.encode(None, '192.168.0.255'), + b'\xc0\xa8\x00\xff') + + def testInvalidAddressEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.ipaddr.encode, None, 1) + + def testIntegerEncoding(self): + self.assertEqual(self.integer.encode(None, 0x01020304), b'\x01\x02\x03\x04') + + def testInteger64Encoding(self): + self.assertEqual( + self.integer64.encode(None, 0xFFFFFFFFFFFFFFFF), b'\xff' * 8 + ) + + def testUnsignedIntegerEncoding(self): + self.assertEqual(self.integer.encode(None, 0xFFFFFFFF), b'\xff\xff\xff\xff') + + def testInvalidIntegerEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.integer.encode, None, 'ONE') + + def testDateEncoding(self): + self.assertEqual(self.date.encode(None, 0x01020304), b'\x01\x02\x03\x04') + + def testInvalidDataEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.date.encode, None, '1') + + def testEncodeAscendBinary(self): + self.assertEqual( + self.abinary.encode(None, 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'), + b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') + + def testStringDecoding(self): + self.assertEqual( + self.string.decode(b'1234567890'), + '1234567890') + + def testAddressDecoding(self): + self.assertEqual( + self.ipaddr.decode(b'\xc0\xa8\x00\xff'), + '192.168.0.255') + + def testIntegerDecoding(self): + self.assertEqual( + self.integer.decode(b'\x01\x02\x03\x04'), + 0x01020304) + + def testInteger64Decoding(self): + self.assertEqual( + self.integer64.decode(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF + ) + + def testDateDecoding(self): + self.assertEqual( + self.date.decode(b'\x01\x02\x03\x04'), + 0x01020304) + + def testOctetsEncoding(self): + self.assertEqual(self.octets.encode(None, '0x01020304'), b'\x01\x02\x03\x04') + self.assertEqual(self.octets.encode(None, b'0x01020304'), b'\x01\x02\x03\x04') + self.assertEqual(self.octets.encode(None, '16909060'), b'\x01\x02\x03\x04') + # encodes to 253 bytes + self.assertEqual(self.octets.encode(None, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D'), b'\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r') + self.assertRaisesRegex(ValueError, 'Can only encode strings of <= 253 characters', self.octets.encode, None, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E') diff --git a/tests/testDictionary.py b/tests/testDictionary.py index 0d1fb99..41b2fd4 100644 --- a/tests/testDictionary.py +++ b/tests/testDictionary.py @@ -7,8 +7,8 @@ from pyrad.dictionary import Attribute from pyrad.dictionary import Dictionary from pyrad.dictionary import ParseError -from pyrad.tools import DecodeAttr from pyrad.dictfile import DictFile +from pyrad.datatypes.leaf import Integer, Integer64, String, Octets class AttributeTests(unittest.TestCase): @@ -19,7 +19,7 @@ def testConstructionParameters(self): attr = Attribute('name', 'code', 'integer', False, 'vendor') self.assertEqual(attr.name, 'name') self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') + self.assertIsInstance(attr.type, Integer) self.assertEqual(attr.is_sub_attribute, False) self.assertEqual(attr.vendor, 'vendor') self.assertEqual(len(attr.values), 0) @@ -30,7 +30,7 @@ def testNamedConstructionParameters(self): vendor='vendor') self.assertEqual(attr.name, 'name') self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') + self.assertIsInstance(attr.type, Integer) self.assertEqual(attr.vendor, 'vendor') self.assertEqual(len(attr.values), 0) @@ -104,7 +104,7 @@ def testParseSimpleDictionary(self): for (attr, code, type) in self.simple_dict_values: attr = self.dict[attr] self.assertEqual(attr.code, code) - self.assertEqual(attr.type, type) + self.assertEqual(attr.type.name, type) def testAttributeTooFewColumnsError(self): try: @@ -168,18 +168,16 @@ def testIntegerValueParsing(self): self.dict.ReadDictionary(StringIO('VALUE Test-Integer Value-Six 5')) self.assertEqual(len(self.dict['Test-Integer'].values), 1) self.assertEqual( - DecodeAttr('integer', - self.dict['Test-Integer'].values['Value-Six']), - 5) + Integer().decode(self.dict['Test-Integer'].values['Value-Six']), + 5) def testInteger64ValueParsing(self): self.assertEqual(len(self.dict['Test-Integer64'].values), 0) self.dict.ReadDictionary(StringIO('VALUE Test-Integer64 Value-Six 5')) self.assertEqual(len(self.dict['Test-Integer64'].values), 1) self.assertEqual( - DecodeAttr('integer64', - self.dict['Test-Integer64'].values['Value-Six']), - 5) + Integer64().decode(self.dict['Test-Integer64'].values['Value-Six']), + 5) def testStringValueParsing(self): self.assertEqual(len(self.dict['Test-String'].values), 0) @@ -187,9 +185,8 @@ def testStringValueParsing(self): 'VALUE Test-String Value-Custard custardpie')) self.assertEqual(len(self.dict['Test-String'].values), 1) self.assertEqual( - DecodeAttr('string', - self.dict['Test-String'].values['Value-Custard']), - 'custardpie') + String().decode(self.dict['Test-String'].values['Value-Custard']), + 'custardpie') def testOctetValueParsing(self): self.assertEqual(len(self.dict['Test-Octets'].values), 0) @@ -199,17 +196,16 @@ def testOctetValueParsing(self): 'VALUE Test-Octets Value-B 0x42\n')) # "B" self.assertEqual(len(self.dict['Test-Octets'].values), 2) self.assertEqual( - DecodeAttr('octets', - self.dict['Test-Octets'].values['Value-A']), - b'A') + Octets().decode(self.dict['Test-Octets'].values['Value-A']), + b'A') self.assertEqual( - DecodeAttr('octets', - self.dict['Test-Octets'].values['Value-B']), - b'B') + Octets().decode(self.dict['Test-Octets'].values['Value-B']), + b'B') def testTlvParsing(self): self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2) - self.assertEqual(self.dict['Test-Tlv'].sub_attributes, {1:'Test-Tlv-Str', 2: 'Test-Tlv-Int'}) + self.assertEqual(self.dict['Test-Tlv'].sub_attributes[1].name, 'Test-Tlv-Str') + self.assertEqual(self.dict['Test-Tlv'].sub_attributes[2].name, 'Test-Tlv-Int') def testSubTlvParsing(self): for (attr, _, _) in self.simple_dict_values: diff --git a/tests/testPacket.py b/tests/testPacket.py index f7649a0..540e32b 100644 --- a/tests/testPacket.py +++ b/tests/testPacket.py @@ -124,7 +124,7 @@ def _create_reply_with_duplicate_attributes(self, request): def _get_attribute_bytes(self, attr_name, value): attr = self.dict.attributes[attr_name] attr_key = attr.code - attr_value = packet.tools.EncodeAttr(attr.type, value) + attr_value = attr.encode(None, value) attr_len = len(attr_value) + 2 return struct.pack('!BB', attr_key, attr_len) + attr_value diff --git a/tests/testTools.py b/tests/testTools.py deleted file mode 100644 index f220e7b..0000000 --- a/tests/testTools.py +++ /dev/null @@ -1,127 +0,0 @@ -from ipaddress import AddressValueError -from pyrad import tools -import unittest - - -class EncodingTests(unittest.TestCase): - def testStringEncoding(self): - self.assertRaises(ValueError, tools.EncodeString, 'x' * 254) - self.assertEqual( - tools.EncodeString('1234567890'), - b'1234567890') - - def testInvalidStringEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeString, 1) - - def testAddressEncoding(self): - self.assertRaises(AddressValueError, tools.EncodeAddress, 'TEST123') - self.assertEqual( - tools.EncodeAddress('192.168.0.255'), - b'\xc0\xa8\x00\xff') - - def testInvalidAddressEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeAddress, 1) - - def testIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0x01020304), b'\x01\x02\x03\x04') - - def testInteger64Encoding(self): - self.assertEqual( - tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), b'\xff' * 8 - ) - - def testUnsignedIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), b'\xff\xff\xff\xff') - - def testInvalidIntegerEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeInteger, 'ONE') - - def testDateEncoding(self): - self.assertEqual(tools.EncodeDate(0x01020304), b'\x01\x02\x03\x04') - - def testInvalidDataEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeDate, '1') - - def testEncodeAscendBinary(self): - self.assertEqual( - tools.EncodeAscendBinary('family=ipv4 action=discard direction=in dst=10.10.255.254/32'), - b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') - - def testStringDecoding(self): - self.assertEqual( - tools.DecodeString(b'1234567890'), - '1234567890') - - def testAddressDecoding(self): - self.assertEqual( - tools.DecodeAddress(b'\xc0\xa8\x00\xff'), - '192.168.0.255') - - def testIntegerDecoding(self): - self.assertEqual( - tools.DecodeInteger(b'\x01\x02\x03\x04'), - 0x01020304) - - def testInteger64Decoding(self): - self.assertEqual( - tools.DecodeInteger64(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF - ) - - def testDateDecoding(self): - self.assertEqual( - tools.DecodeDate(b'\x01\x02\x03\x04'), - 0x01020304) - - def testOctetsEncoding(self): - self.assertEqual(tools.EncodeOctets('0x01020304'), b'\x01\x02\x03\x04') - self.assertEqual(tools.EncodeOctets(b'0x01020304'), b'\x01\x02\x03\x04') - self.assertEqual(tools.EncodeOctets('16909060'), b'\x01\x02\x03\x04') - # encodes to 253 bytes - self.assertEqual(tools.EncodeOctets('0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D'), b'\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r') - self.assertRaisesRegex(ValueError, 'Can only encode strings of <= 253 characters', tools.EncodeOctets, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E') - - def testUnknownTypeEncoding(self): - self.assertRaises(ValueError, tools.EncodeAttr, 'unknown', None) - - def testUnknownTypeDecoding(self): - self.assertRaises(ValueError, tools.DecodeAttr, 'unknown', None) - - def testEncodeFunction(self): - self.assertEqual( - tools.EncodeAttr('string', 'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('ipaddr', '192.168.0.255'), - b'\xc0\xa8\x00\xff') - self.assertEqual( - tools.EncodeAttr('integer', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('date', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF), - b'\xff'*8) - - def testDecodeFunction(self): - self.assertEqual( - tools.DecodeAttr('string', b'string'), - 'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.DecodeAttr('ipaddr', b'\xc0\xa8\x00\xff'), - '192.168.0.255') - self.assertEqual( - tools.DecodeAttr('integer', b'\x01\x02\x03\x04'), - 0x01020304) - self.assertEqual( - tools.DecodeAttr('integer64', b'\xff'*8), - 0xFFFFFFFFFFFFFFFF) - self.assertEqual( - tools.DecodeAttr('date', b'\x01\x02\x03\x04'), - 0x01020304)