diff --git a/examples/shift/byte_shift.py b/examples/shift/byte_shift.py new file mode 100755 index 0000000..fec5a12 --- /dev/null +++ b/examples/shift/byte_shift.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +from lantern.modules import shift +import binascii + + +def shift_bytes(key: int, byte: int) -> int: + """Subtract byte by key""" + return byte - key + + +ciphertext = [0xed, 0xbc, 0xcd, 0xfe] + +KEY = 15 +decryption = shift.decrypt(KEY, ciphertext, shift_bytes) +print(binascii.hexlify(bytearray(decryption))) diff --git a/lantern/modules/shift.py b/lantern/modules/shift.py index 20407c4..b43727b 100644 --- a/lantern/modules/shift.py +++ b/lantern/modules/shift.py @@ -2,11 +2,17 @@ import string +from typing import Callable, Iterable + from lantern import score from lantern.structures import Decryption +ShiftOperator = Callable[[int, int], int] +subtract: ShiftOperator = lambda a, b: a - b +add: ShiftOperator = lambda a, b: a + b + -def make_shift_function(alphabet): +def make_shift_function(alphabet, operator: ShiftOperator=subtract): """Construct a shift function from an alphabet. Examples: @@ -38,14 +44,16 @@ def shift_case_sensitive(shift, symbol): case = case[0] index = case.index(symbol) - return case[(index - shift) % len(case)] + return case[(operator(index, shift)) % len(case)] return shift_case_sensitive -shift_case_english = make_shift_function([string.ascii_uppercase, string.ascii_lowercase]) +shift_decrypt_case_english = make_shift_function([string.ascii_uppercase, string.ascii_lowercase], subtract) +shift_encrypt_case_english = make_shift_function([string.ascii_uppercase, string.ascii_lowercase], add) -def crack(ciphertext, *fitness_functions, min_key=0, max_key=26, shift_function=shift_case_english): + +def crack(ciphertext, *fitness_functions, min_key=0, max_key=26, shift_function=shift_decrypt_case_english): """Break ``ciphertext`` by enumerating keys between ``min_key`` and ``max_key``. Example: @@ -74,20 +82,20 @@ def crack(ciphertext, *fitness_functions, min_key=0, max_key=26, shift_function= decryptions = [] for key in range(min_key, max_key): - plaintext = decrypt(key, ciphertext, shift_function=shift_function) + plaintext = decrypt(key, ciphertext, shift_function) decryptions.append(Decryption(plaintext, key, score(plaintext, *fitness_functions))) return sorted(decryptions, reverse=True) -def decrypt(key, ciphertext, shift_function=shift_case_english): +def decrypt(key, ciphertext, shift_function=shift_decrypt_case_english) -> Iterable: """Decrypt Shift enciphered ``ciphertext`` using ``key``. Examples: >>> ''.join(decrypt(3, "KHOOR")) HELLO - >>> decrypt(15, [0xcf, 0x9e, 0xaf, 0xe0], shift_bytes) + >>> decrypt(15, [0xed, 0xbc, 0xcd, 0xfe], shift_bytes) [0xde, 0xad, 0xbe, 0xef] Args: @@ -96,6 +104,27 @@ def decrypt(key, ciphertext, shift_function=shift_case_english): shift_function (function (shift, symbol)): Shift function to apply to symbols in the ciphertext Returns: - Decrypted ciphertext, list of plaintext symbols + Decrypted text """ return [shift_function(key, symbol) for symbol in ciphertext] + + +def encrypt(key: int, plaintext: Iterable, shift_function=shift_encrypt_case_english) -> Iterable: + """Encrypt ``plaintext`` with ``key`` using the shift cipher. + + Examples: + >>> ''.join(encrypt(3, "HELLO")) + KHOOR + + >>> encrypt(15, [0xde, 0xad, 0xbe, 0xef], shift_bytes) + [0xed, 0xbc, 0xcd, 0xfe] + + Args: + key (int): The shift to use + plaintext (iterable): The symbols to encrypt + shift_function (function (shift, symbol)): Shift function to apply to symbols in the plaintext + + Returns: + Encrypted text + """ + return decrypt(key, plaintext, shift_function) diff --git a/tests/modules/test_shift.py b/tests/modules/test_shift.py index a5d3e80..26b0138 100644 --- a/tests/modules/test_shift.py +++ b/tests/modules/test_shift.py @@ -180,7 +180,13 @@ def test_multi_symbol_decryption(): # We can solve it using this method, or by seperating the shifted letters and decrypting that. def shift_function(shift, symbol): a, b = symbol[:] - print(a) return a + chr(ord(b) - shift) assert ''.join(shift.decrypt(1, ciphertext, shift_function)) == "TEST" + + +def test_encrypt(): + """Test shift encryption""" + plaintext = "HELLO" + + assert ''.join(shift.encrypt(3, plaintext)) == "KHOOR"