diff --git a/pike/__init__.py b/pike/__init__.py index 005ebc93..1fff2898 100644 --- a/pike/__init__.py +++ b/pike/__init__.py @@ -13,5 +13,5 @@ 'test', 'transport', ] -__version_info__ = (0, 2, 14) +__version_info__ = (0, 2, 15) __version__ = "{0}.{1}.{2}".format(*__version_info__) diff --git a/pike/model.py b/pike/model.py index dc65eaf9..b5a3ece0 100644 --- a/pike/model.py +++ b/pike/model.py @@ -477,6 +477,8 @@ def __init__(self, client, server, port=445): self._binding_key = None self._settings = {} self._pre_auth_integrity_hash = array.array('B', "\0"*64) + self._negotiate_request = None + self._negotiate_response = None self.callbacks = {} self.connection_future = Future() self.credits = 0 @@ -549,31 +551,24 @@ def process_callbacks(self, event, obj): for cb in callbacks[ev]: cb(obj) - def smb3_pa_integrity(self, packet, data=None): - """ perform smb3 pre-auth integrity hash update if needed """ - if smb2.DIALECT_SMB3_1_1 not in self.client.dialects: - # hash only applies if client requests 3.1.1 + @property + def negotiate_response(self): + return self._negotiate_response + + @negotiate_response.setter + def negotiate_response(self, value): + if self._negotiate_response is not None: + raise AssertionError("Attempting to overwrite negotiate response") + self._negotiate_response = value + # pre-auth integrity hash processing + if value.dialect_revision < smb2.DIALECT_SMB3_1_1: return - neg_resp = getattr(self, "negotiate_response", None) - if (neg_resp is not None and - neg_resp.dialect_revision < smb2.DIALECT_SMB3_1_1): - # hash only applies if server negotiates 3.1.1 - return - if packet[0].__class__ not in [smb2.NegotiateRequest, - smb2.NegotiateResponse, - smb2.SessionSetupRequest, - smb2.SessionSetupResponse]: - # hash only applies to pre-auth messages - return - if (packet[0].__class__ == smb2.SessionSetupResponse and - packet.status == ntstatus.STATUS_SUCCESS): - # last session setup doesn't count in hash - return - if data is None: - data = packet.serialize() self._pre_auth_integrity_hash = digest.smb3_sha512( self._pre_auth_integrity_hash + - data) + self._negotiate_request.parent.serialize()) + self._pre_auth_integrity_hash = digest.smb3_sha512( + self._pre_auth_integrity_hash + + value.parent.parent.buf[4:]) def next_mid_range(self, length): """ @@ -729,8 +724,6 @@ def _prepare_outgoing(self): if req.is_last_child(): # Last command in chain, ready to send packet - # TODO: move smb pa integrity to callback - self.smb3_pa_integrity(req) result = req.parent.serialize() self.process_callbacks(EV_REQ_POST_SERIALIZE, req.parent) if trace: @@ -777,8 +770,7 @@ def _dispatch_incoming(self, res): ', '.join(f[0].__class__.__name__ for f in res)) self.process_callbacks(EV_RES_POST_DESERIALIZE, res) for smb_res in res: - # TODO: move smb pa integrity and credit tracking to callbacks - self.smb3_pa_integrity(smb_res, smb_res.parent.buf[4:]) + # TODO: move credit tracking to callbacks self.credits += smb_res.credit_response # Verify non-session-setup-response signatures @@ -888,6 +880,7 @@ def negotiate_request(self, hash_algorithms=None, salt=None, ciphers=None): else: preauth_integrity_req.salt = array.array('B', map(random.randint, [0]*32, [255]*32)) + self._negotiate_request = neg_req return neg_req def negotiate_submit(self, negotiate_request): @@ -921,6 +914,7 @@ def __init__(self, conn, creds=None, bind=None, resume=None, self.dialect_revision = conn.negotiate_response.dialect_revision self.bind = bind self.resume = resume + self._pre_auth_integrity_hash = conn._pre_auth_integrity_hash[:] if creds and auth.ntlm is not None: self.auth = auth.NtlmProvider(conn, creds) @@ -958,7 +952,7 @@ def derive_signing_key(self, session_key=None, context=None): session_key = self.session_key if self.dialect_revision >= smb2.DIALECT_SMB3_1_1: if context is None: - context = self.conn._pre_auth_integrity_hash + context = self._pre_auth_integrity_hash return digest.derive_key( session_key, 'SMBSigningKey', @@ -973,7 +967,7 @@ def derive_signing_key(self, session_key=None, context=None): def derive_encryption_keys(self, session_key=None, context=None): if self.dialect_revision >= smb2.DIALECT_SMB3_1_1: if context is None: - context = self.conn._pre_auth_integrity_hash + context = self._pre_auth_integrity_hash for nctx in self.conn.negotiate_response: if isinstance(nctx, crypto.EncryptionCapabilitiesResponse): try: @@ -990,6 +984,21 @@ def derive_encryption_keys(self, session_key=None, context=None): crypto.CryptoKeys300(self.session_key), [crypto.SMB2_AES_128_CCM]) + def _update_pre_auth_integrity(self, packet, data=None): + if smb2.DIALECT_SMB3_1_1 not in self.conn.client.dialects: + # hash only applies if client requests 3.1.1 + return + neg_resp = self.conn.negotiate_response + if (neg_resp is not None and + neg_resp.dialect_revision < smb2.DIALECT_SMB3_1_1): + # hash only applies if server negotiates 3.1.1 + return + if data is None: + data = packet.serialize() + self._pre_auth_integrity_hash = digest.smb3_sha512( + self._pre_auth_integrity_hash + + data) + def _send_session_setup(self, sec_buf): smb_req = self.conn.request() session_req = smb2.SessionSetupRequest(smb_req) @@ -1062,6 +1071,10 @@ def _process(self): self.conn.negotiate_response.security_buffer) elif self.interim_future: + # handle pre-auth integrity on the previous request + previous_request = self.requests[-1] + self._update_pre_auth_integrity(previous_request) + smb_res = self.interim_future.result() self.interim_future = None self.responses.append(smb_res) @@ -1074,6 +1087,7 @@ def _process(self): return self.session_future else: # process interim request + self._update_pre_auth_integrity(smb_res, smb_res.parent.buf[4:]) session_res = smb_res[0] if self.bind: # Need to verify intermediate signatures diff --git a/pike/test/session.py b/pike/test/session.py index fdd01e48..cb09d6c7 100644 --- a/pike/test/session.py +++ b/pike/test/session.py @@ -43,3 +43,20 @@ class SessionTest(pike.test.PikeTest): def test_session_logoff(self): chan, tree = self.tree_connect() chan.logoff() + def test_session_multiplex(self): + chan, tree = self.tree_connect() + chan2 = chan.connection.session_setup(self.creds) + chan3 = chan.connection.session_setup(self.creds) + self.assertEqual(chan.connection, chan2.connection) + self.assertEqual(chan2.connection, chan3.connection) + self.assertNotEqual(chan.session, chan2.session) + self.assertNotEqual(chan2.session, chan3.session) + self.assertNotEqual(chan.session.session_id, chan2.session.session_id) + self.assertNotEqual(chan2.session.session_id, chan3.session.session_id) + self.assertNotEqual(chan.session.session_key, chan2.session.session_key) + self.assertNotEqual(chan2.session.session_key, chan3.session.session_key) + tree2 = chan2.tree_connect(self.share) + tree3 = chan3.tree_connect(self.share) + chan3.logoff() + chan2.logoff() + chan.logoff()