Skip to content

Commit

Permalink
blackify pike/test/*.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmoon79 committed Feb 12, 2021
1 parent d5d1166 commit e347aeb
Show file tree
Hide file tree
Showing 25 changed files with 1,923 additions and 1,386 deletions.
151 changes: 96 additions & 55 deletions pike/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CapabilityMissing(TestRequirementNotMet):
class ShareCapabilityMissing(TestRequirementNotMet):
pass


class Options(enum.Enum):
PIKE_LOGLEVEL = "PIKE_LOGLEVEL"
PIKE_TRACE = "PIKE_TRACE"
Expand All @@ -83,22 +84,23 @@ def option(cls, name, default=None):
if value is _NotSet or len(value) == 0:
if default is _Required:
raise MissingArgument(
"Environment variable {!r} must be set".format(name))
"Environment variable {!r} must be set".format(name)
)
value = default
return value

@classmethod
def booloption(cls, name, default='no'):
table = {'yes': True, 'true': True, 'no': False, 'false': False, '': False}
def booloption(cls, name, default="no"):
table = {"yes": True, "true": True, "no": False, "false": False, "": False}
return table[cls.option(name, default=default).lower()]

@classmethod
def smb2constoption(cls, name, default=None):
return getattr(smb2, cls.option(name, '').upper(), default)
return getattr(smb2, cls.option(name, "").upper(), default)

@classmethod
def loglevel(cls):
return getattr(logging, cls.option(cls.PIKE_LOGLEVEL, default='NOTSET').upper())
return getattr(logging, cls.option(cls.PIKE_LOGLEVEL, default="NOTSET").upper())

@classmethod
def trace(cls):
Expand All @@ -110,7 +112,7 @@ def server(cls):

@classmethod
def port(cls):
return int(cls.option(cls.PIKE_PORT, default='445'))
return int(cls.option(cls.PIKE_PORT, default="445"))

@classmethod
def creds(cls):
Expand All @@ -134,7 +136,7 @@ def min_dialect(cls):

@classmethod
def max_dialect(cls):
return cls.smb2constoption(cls.PIKE_MAX_DIALECT, default=float('inf'))
return cls.smb2constoption(cls.PIKE_MAX_DIALECT, default=float("inf"))


def default_client(signing=None):
Expand All @@ -145,8 +147,9 @@ def default_client(signing=None):
if signing is None:
signing = Options.signing()
if signing:
client.security_mode = (smb2.SMB2_NEGOTIATE_SIGNING_ENABLED |
smb2.SMB2_NEGOTIATE_SIGNING_REQUIRED)
client.security_mode = (
smb2.SMB2_NEGOTIATE_SIGNING_ENABLED | smb2.SMB2_NEGOTIATE_SIGNING_REQUIRED
)
return client


Expand All @@ -155,6 +158,7 @@ class TreeConnect(object):
"""
Combines a Client, Connection, Channel, and Tree for simple access to an SMB share.
"""

_client = attr.ib(default=None)
server = attr.ib(factory=Options.server)
port = attr.ib(factory=Options.port)
Expand Down Expand Up @@ -193,21 +197,29 @@ def connect(self):
:return: connected pike.model.Connection
"""
if self.conn and self.conn.connected:
raise SequenceError("Already connected: {!r}. Must call close() before reconnecting".format(self.conn))
raise SequenceError(
"Already connected: {!r}. Must call close() before reconnecting".format(
self.conn
)
)
self.conn = self.client.connect(server=self.server, port=self.port).negotiate()
negotiated_dialect = self.conn.negotiate_response.dialect_revision
if (self.require_dialect and
(negotiated_dialect < self.require_dialect[0] or
negotiated_dialect > self.require_dialect[1])):
if self.require_dialect and (
negotiated_dialect < self.require_dialect[0]
or negotiated_dialect > self.require_dialect[1]
):
self.close()
raise DialectMissing("Dialect required: {}".format(self.require_dialect))

capabilities = self.conn.negotiate_response.capabilities
if (self.require_capabilities and
(capabilities & self.require_capabilities != self.require_capabilities)):
if self.require_capabilities and (
capabilities & self.require_capabilities != self.require_capabilities
):
self.close()
raise CapabilityMissing("Server does not support: %s " %
str(self.require_capabilities & ~capabilities))
raise CapabilityMissing(
"Server does not support: %s "
% str(self.require_capabilities & ~capabilities)
)
return self.conn

def session_setup(self):
Expand All @@ -225,7 +237,11 @@ def session_setup(self):
if not self.conn or not self.conn.connected:
raise SequenceError("Not connected. Must call connect() first")
if self.chan:
raise SequenceError("Channel already established: {!r}. Must call close() before reconnecting".format(self.chan))
raise SequenceError(
"Channel already established: {!r}. Must call close() before reconnecting".format(
self.chan
)
)
self.chan = self.conn.session_setup(self.creds, resume=self.resume)
if self.encryption:
self.chan.session.encrypt_data = True
Expand All @@ -242,17 +258,25 @@ def tree_connect(self):
:return: pike.model.Tree
"""
if not self.chan:
raise SequenceError("Channel not established. Must call session_setup() first")
raise SequenceError(
"Channel not established. Must call session_setup() first"
)
if self.tree:
raise SequenceError("Tree already connected: {!r}. Must call close() before reconnecting".format(self.tree))
raise SequenceError(
"Tree already connected: {!r}. Must call close() before reconnecting".format(
self.tree
)
)
self.tree = self.chan.tree_connect(self.share)
capabilities = self.tree.tree_connect_response.capabilities
if (self.require_share_capabilities and
(capabilities & self.require_share_capabilities != self.require_share_capabilities)):
if self.require_share_capabilities and (
capabilities & self.require_share_capabilities
!= self.require_share_capabilities
):
self.close()
raise ShareCapabilityMissing(
"Share does not support: %s" %
str(self.require_share_capabilities & ~capabilities)
"Share does not support: %s"
% str(self.require_share_capabilities & ~capabilities)
)
return self.tree

Expand Down Expand Up @@ -355,8 +379,10 @@ def init_once():
PikeTest.loglevel = Options.loglevel()
PikeTest.handler = logging.StreamHandler()
PikeTest.handler.setLevel(PikeTest.loglevel)
PikeTest.handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s: %(message)s'))
PikeTest.logger = logging.getLogger('pike')
PikeTest.handler.setFormatter(
logging.Formatter("%(asctime)s:%(name)s:%(levelname)s: %(message)s")
)
PikeTest.logger = logging.getLogger("pike")
PikeTest.logger.addHandler(PikeTest.handler)
PikeTest.logger.setLevel(PikeTest.loglevel)
model.trace = PikeTest.trace = Options.trace()
Expand Down Expand Up @@ -391,8 +417,7 @@ def error(self, *args, **kwargs):
def critical(self, *args, **kwargs):
self.logger.critical(*args, **kwargs)

def set_client_dialect(self, min_dialect=None, max_dialect=None,
client=None):
def set_client_dialect(self, min_dialect=None, max_dialect=None, client=None):
if client is None:
client = self.default_client
client.restrict_dialects(min_dialect, max_dialect)
Expand Down Expand Up @@ -434,20 +459,21 @@ def assert_error(self, status):
if err.response.status != status:
raise_from(
self.failureException(
'"%s" raised when "%s" expected' % (err.response.status, status),
'"%s" raised when "%s" expected'
% (err.response.status, status),
),
err
err,
)

def setUp(self):
if self.loglevel != logging.NOTSET:
print(file=sys.stderr)

if hasattr(self, 'setup'):
if hasattr(self, "setup"):
self.setup()

def tearDown(self):
if hasattr(self, 'teardown'):
if hasattr(self, "teardown"):
self.teardown()

for conn in self._connections:
Expand All @@ -457,7 +483,7 @@ def tearDown(self):
gc.collect()

def _get_decorator_attr(self, name, default):
name = '__pike_test_' + name
name = "__pike_test_" + name
test_method = getattr(self, self._testMethodName)

if hasattr(test_method, name):
Expand All @@ -468,13 +494,13 @@ def _get_decorator_attr(self, name, default):
return default

def required_dialect(self):
return self._get_decorator_attr('RequireDialect', (0, float('inf')))
return self._get_decorator_attr("RequireDialect", (0, float("inf")))

def required_capabilities(self):
return self._get_decorator_attr('RequireCapabilities', 0)
return self._get_decorator_attr("RequireCapabilities", 0)

def required_share_capabilities(self):
return self._get_decorator_attr('RequireShareCapabilities', 0)
return self._get_decorator_attr("RequireShareCapabilities", 0)

def assertBufferEqual(self, buf1, buf2):
"""
Expand All @@ -489,64 +515,79 @@ def assertBufferEqual(self, buf1, buf2):
# XXX: consider usage of stdlib bisect module
chunk_1 = (low, low + ((high - low) // 2))
chunk_2 = (chunk_1[1], high)
if buf1[chunk_1[0]:chunk_1[1]] != buf2[chunk_1[0]:chunk_1[1]]:
if buf1[chunk_1[0] : chunk_1[1]] != buf2[chunk_1[0] : chunk_1[1]]:
low, high = chunk_1
elif buf1[chunk_2[0]:chunk_2[1]] != buf2[chunk_2[0]:chunk_2[1]]:
elif buf1[chunk_2[0] : chunk_2[1]] != buf2[chunk_2[0] : chunk_2[1]]:
low, high = chunk_2
else:
break
if high - low <= 1:
raise AssertionError("Block mismatch at byte {0}: "
"{1} != {2}".format(low, buf1[low], buf2[low]))
raise AssertionError(
"Block mismatch at byte {0}: "
"{1} != {2}".format(low, buf1[low], buf2[low])
)


class _Decorator(object):
def __init__(self, value):
self.value = value

def __call__(self, thing):
setattr(thing, '__pike_test_' + self.__class__.__name__, self.value)
setattr(thing, "__pike_test_" + self.__class__.__name__, self.value)
return thing


class _RangeDecorator(object):
def __init__(self, minvalue=0, maxvalue=float('inf')):
def __init__(self, minvalue=0, maxvalue=float("inf")):
self.minvalue = minvalue
self.maxvalue = maxvalue

def __call__(self, thing):
setattr(thing, '__pike_test_' + self.__class__.__name__,
(self.minvalue, self.maxvalue))
setattr(
thing,
"__pike_test_" + self.__class__.__name__,
(self.minvalue, self.maxvalue),
)
return thing


class RequireDialect(_RangeDecorator): pass
class RequireCapabilities(_Decorator): pass
class RequireShareCapabilities(_Decorator): pass
class RequireDialect(_RangeDecorator):
pass


class RequireCapabilities(_Decorator):
pass


class RequireShareCapabilities(_Decorator):
pass


class PikeTestSuite(unittest.TestSuite):
"""
Custom test suite for easily patching in skip tests in downstream
distributions of these test cases
"""

skip_tests_reasons = {
"test_to_be_skipped": "This test should be skipped",
"test_to_be_skipped": "This test should be skipped",
}

@staticmethod
def _raise_skip(reason):
def inner(*args, **kwds):
raise unittest.SkipTest(reason)

return inner

def addTest(self, test):
testMethodName = getattr(test, "_testMethodName", None)
if testMethodName in self.skip_tests_reasons:
setattr(
test,
testMethodName,
self._raise_skip(
self.skip_tests_reasons[testMethodName]))
test,
testMethodName,
self._raise_skip(self.skip_tests_reasons[testMethodName]),
)
super(PikeTestSuite, self).addTest(test)


Expand Down Expand Up @@ -611,15 +652,15 @@ def suite(clz=PikeTestSuite):
test_loader = unittest.TestLoader()
test_loader.suiteClass = clz
test_suite = test_loader.discover(
os.path.abspath(os.path.dirname(__file__)),
"*.py")
os.path.abspath(os.path.dirname(__file__)), "*.py"
)
return test_suite


def samba_suite():
return suite(SambaPikeTestSuite)


if __name__ == '__main__':
if __name__ == "__main__":
test_runner = unittest.TextTestRunner()
test_runner.run(suite())
Loading

0 comments on commit e347aeb

Please sign in to comment.