From 2c7bb045c467d012be49d72c0853e136e1838445 Mon Sep 17 00:00:00 2001
From: thez3ro <io@thezero.org>
Date: Thu, 14 Sep 2017 19:03:08 +0200
Subject: [PATCH] permit custom SBox-es

---
 pyaes.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/pyaes.py b/pyaes.py
index 5d20c48..75403e7 100644
--- a/pyaes.py
+++ b/pyaes.py
@@ -64,13 +64,13 @@
 key_size = None
 
 
-def new(key, mode, IV=None):
+def new(key, mode, IV=None, Sbox=None, Sboxi=None):
     if mode == MODE_ECB:
-        return ECBMode(AES(key))
+        return ECBMode(AES(key,Sbox,Sboxi))
     elif mode == MODE_CBC:
         if IV is None:
             raise ValueError("CBC mode needs an IV value!")
-        return CBCMode(AES(key), IV)
+        return CBCMode(AES(key,Sbox,Sboxi), IV)
     else:
         raise NotImplementedError
 
@@ -79,7 +79,13 @@ def new(key, mode, IV=None):
 class AES(object):
     block_size = 16
 
-    def __init__(self, key):
+    def __init__(self, key, Sbox=None, Sboxi=None):
+        self.S = Sbox
+        self.Si = Sboxi
+        if Sbox is None:
+            self.S=aes_sbox
+        if Sboxi is None:
+            self.Si=aes_inv_sbox
         self.setkey(key)
 
     def setkey(self, key):
@@ -131,7 +137,7 @@ def expand_key(self):
 
             # apply S-box to all bytes
             for j in xrange(4):
-                word[j] = aes_sbox[word[j]]
+                word[j] = self.S[word[j]]
 
             # apply the Rcon table to the leftmost byte
             word[0] ^= aes_Rcon[i]
@@ -152,7 +158,7 @@ def expand_key(self):
                 for j in xrange(4):
                     # mix in bytes from the last subkey XORed with S-box of
                     # current word bytes
-                    word[j] = aes_sbox[word[j]] ^ exkey[-self.key_size + j]
+                    word[j] = self.S[word[j]] ^ exkey[-self.key_size + j]
                 exkey.extend(word)
 
             # Twice for 192-bit key, thrice for 256-bit key
@@ -268,12 +274,12 @@ def encrypt_block(self, block):
         self.add_round_key(block, 0)
 
         for round in xrange(1, self.rounds):
-            self.sub_bytes(block, aes_sbox)
+            self.sub_bytes(block, self.S)
             self.shift_rows(block)
             self.mix_columns(block)
             self.add_round_key(block, round)
 
-        self.sub_bytes(block, aes_sbox)
+        self.sub_bytes(block, self.S)
         self.shift_rows(block)
         # no mix_columns step in the last round
         self.add_round_key(block, self.rounds)
@@ -288,12 +294,12 @@ def decrypt_block(self, block):
         # count rounds down from (self.rounds) ... 1
         for round in xrange(self.rounds - 1, 0, -1):
             self.shift_rows_inv(block)
-            self.sub_bytes(block, aes_inv_sbox)
+            self.sub_bytes(block, self.Si)
             self.add_round_key(block, round)
             self.mix_columns_inv(block)
 
         self.shift_rows_inv(block)
-        self.sub_bytes(block, aes_inv_sbox)
+        self.sub_bytes(block, self.Si)
         self.add_round_key(block, 0)
         # no mix_columns step in the last round