Published on

UofTCTF 2024 – Export Grade Cipher

Authors
  • avatar
    Name
    Wrth
    Description
    sleeps way too much

Export Grade Cipher (Crypto, 10 Solves)

This "state of the art"™ cipher can be exported to your enemies without restriction.

Author: nullptr

nc 0.cloud.chals.io 23753

Attachment & official solution

This crypto challenge is a custom cipher, so we can’t really rely on any known attacks. It doesn’t require much complex math knowledge, but the difficulty lies in being able to extract any useful information from this chosen-plaintext attack.

We are given the source code and a module, this is what chal.py looks like:

import ast
import threading
from exportcipher import *
try:
    from flag import FLAG
except:
    FLAG = "test{FLAG}"

# Upper bound on the number of plaintexts encrypted per session (loop bound in the main script).
MAX_COUNT = 100
# Delay handed to threading.Timer before timeout_handler is invoked.
TIMEOUT = 120 # seconds

def input_bytes(display_msg):
    """Prompt for one message and return it as ``bytes``.

    The typed line is first tried as a Python literal so that arbitrary byte
    values can be entered as ``b'\\x00\\xff'``.  A ``bytes`` literal is
    returned as-is and a ``str`` literal is encoded.  Anything else -- text
    that is not a literal at all, or a literal of another type such as
    ``123`` or ``[1, 2]`` -- is treated as plain text and the raw line is
    encoded, so every line the user can type yields a ``bytes`` result.

    Args:
        display_msg: Prompt passed to ``input``.

    Returns:
        The message as ``bytes`` (empty when the user just presses Enter).
    """
    raw = input(display_msg)
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # not a literal but could still be valid input, so just encode it
        parsed = raw
    if isinstance(parsed, bytes):
        return parsed
    if isinstance(parsed, str):
        return parsed.encode()
    # A valid literal of some other type (int, list, ...): fall back to the
    # text exactly as typed instead of aborting the session.
    return raw.encode()

def timeout_handler():
    """Announce the timeout and terminate the whole process.

    This runs on the ``threading.Timer`` worker thread.  ``exit()`` /
    ``sys.exit()`` only raise ``SystemExit`` in the calling thread, which
    would end the timer thread while the main thread keeps waiting on
    ``input``.  ``os._exit`` ends the process from any thread.
    """
    import os

    # os._exit skips interpreter cleanup, so flush the message explicitly.
    print("Time is up, you can throw out your work as the key changed.", flush=True)
    os._exit(0)

if __name__ == "__main__":
    # Session setup: draw a fresh random 5-byte key and announce the rules.
    print("Initializing Export Grade Cipher...")
    key = int.from_bytes(os.urandom(5))
    cipher = ExportGradeCipher(key)
    for banner in (
        "You may choose up to {} plaintext messages to encrypt.".format(MAX_COUNT),
        "Recover the 40-bit key to get the flag.",
        "You have {} seconds.".format(TIMEOUT),
    ):
        print(banner)
    # enough time to crack a 40 bit key with the compute resources of a government
    threading.Timer(TIMEOUT, timeout_handler).start()

    # Encryption oracle: an empty message ends the phase early; an over-long
    # message is rejected without consuming one of the MAX_COUNT slots.
    served = 0
    while served < MAX_COUNT:
        plaintext = input_bytes("[MSG {}] plaintext: ".format(served))
        if not plaintext:
            break
        if len(plaintext) > 512:
            print("Message Too Long!")
            continue
        nonce = os.urandom(256)
        cipher.init_with_nonce(nonce)
        ciphertext = cipher.encrypt(plaintext)
        print("[MSG {}] nonce: {}".format(served, nonce))
        print("[MSG {}] ciphertext: {}".format(served, ciphertext))
        # Round-trip check: re-initialise with the same nonce and decrypt.
        cipher.init_with_nonce(nonce)
        assert plaintext == cipher.decrypt(ciphertext)
        served += 1

    # Final guess: must be an integer literal equal to the session key.
    recovered_key = ast.literal_eval(input("Recovered Key: "))
    assert isinstance(recovered_key, int)
    print(
        "That is the key! Here is the flag: {}".format(FLAG)
        if recovered_key == key
        else "Wrong!"
    )

So here we can encrypt up to 100 plaintexts using ExportGradeCipher, each with a different nonce. Then we have to guess the 40-bit key.

So how does ExportGradeCipher work?

import os

class LFSR:
    """Left-shifting linear feedback shift register of ``size`` bits.

    Each step moves the state one bit to the left and inserts, as the new
    lowest bit, the XOR of the state bits selected by ``taps``.
    """

    def __init__(self, seed, taps, size):
        """Store the register parameters.

        Args:
            seed: Initial state; must be non-zero and fit in ``size`` bits.
            taps: Non-empty collection of bit positions; must contain
                ``size - 1``.
            size: Register width in bits.
        """
        assert seed != 0
        assert (seed >> size) == 0
        assert len(taps) > 0 and (size - 1) in taps
        self.state = seed
        self.taps = taps
        self.mask = (1 << size) - 1

    def _shift(self):
        """Advance the register by one bit."""
        # The XOR of single bits equals the parity of their sum.
        parity = sum((self.state >> position) & 1 for position in self.taps) & 1
        shifted = (self.state << 1) | parity
        self.state = shifted & self.mask

    def next_byte(self):
        """Return the current low 8 state bits, then advance by 8 steps."""
        output = self.state & 0xFF
        steps_left = 8
        while steps_left:
            self._shift()
            steps_left -= 1
        return output


class ExportGradeCipher:
    """Byte-substitution cipher driven by two LFSRs seeded from a 40-bit key.

    ``init_with_nonce`` builds a 256-entry substitution table ``S`` (and its
    inverse ``S_inv``) from the key and a nonce.  Every processed byte is
    substituted through the table, after which the table entry just used is
    swapped with an entry chosen by the next keystream byte.
    """

    def __init__(self, key):
        """Remember the key; the tables are built later by init_with_nonce."""
        # 40 bit key
        assert (key >> 40) == 0
        self.key = key
        self.initialized = False

    def _keystream_byte(self):
        """Return the XOR of the next output byte of both LFSRs."""
        return self.lfsr17.next_byte() ^ self.lfsr32.next_byte()

    def init_with_nonce(self, nonce):
        """(Re)build both LFSRs and the substitution tables for ``nonce``.

        Args:
            nonce: Exactly 256 byte values, indexable as integers.
        """
        # 256 byte nonce, nonce size isnt export controlled so hopefully this will compensate for the short key size
        assert len(nonce) == 256
        low_bits = self.key & 0xFFFF
        high_bits = self.key >> 16
        # Bit 16 is forced on for the 17-bit register; the 32-bit register
        # gets the constant 0xAB in its top byte.
        self.lfsr17 = LFSR(low_bits | (1 << 16), [2, 9, 10, 11, 14, 16], 17)
        self.lfsr32 = LFSR(
            (high_bits | (0xAB << 24)) & 0xFFFFFFFF,
            [1, 6, 16, 21, 23, 24, 25, 26, 30, 31],
            32,
        )
        self.S = list(range(256))
        # Fisher-Yates shuffle S-table
        for hi in range(255, 0, -1):
            # pick 0 <= lo <= hi, has modulo bias but good luck exploiting that
            lo = self._keystream_byte() % (hi + 1)
            self.S[hi], self.S[lo] = self.S[lo], self.S[hi]
        # use nonce to scramble S-table some more
        cursor = 0
        for pos in range(256):
            from17 = self.lfsr17.next_byte()
            from32 = self.lfsr32.next_byte()
            # '+' binds tighter than '^', so the mix groups like this.
            cursor = ((cursor + from17) ^ (from32 + self.S[pos] + nonce[pos])) % 256
            self.S[pos], self.S[cursor] = self.S[cursor], self.S[pos]
        self.S_inv = [0] * 256
        for index, value in enumerate(self.S):
            self.S_inv[value] = index
        self.initialized = True

    def _update(self, v):
        """Swap table entry ``v`` with a keystream-chosen entry."""
        other = self._keystream_byte()
        at_v, at_other = self.S[v], self.S[other]
        self.S[v], self.S[other] = at_other, at_v
        # Keep the inverse table consistent with the two moved values.
        self.S_inv[at_other] = v
        self.S_inv[at_v] = other

    def encrypt(self, msg):
        """Return the ciphertext bytes for ``msg`` (an iterable of byte values)."""
        assert self.initialized
        out = bytearray()
        for v in msg:
            out.append(self.S[v])
            self._update(v)
        return bytes(out)

    def decrypt(self, ct):
        """Return the plaintext bytes for ``ct``; needs the same initial state as encrypt."""
        assert self.initialized
        out = bytearray()
        for v in ct:
            plain = self.S_inv[v]
            out.append(plain)
            self._update(plain)
        return bytes(out)


if __name__ == "__main__":
    # Self-test: encrypt a sample message under a random key and nonce, then
    # decrypt it from a freshly re-initialised state and compare.
    cipher = ExportGradeCipher(int.from_bytes(os.urandom(5)))
    nonce = os.urandom(256)
    print("="*50)
    print("Cipher Key: {}".format(cipher.key))
    print("Nonce: {}".format(nonce))
    msg = "..."
    print("="*50)
    print("Plaintext: {}".format(msg))
    cipher.init_with_nonce(nonce)
    ct = cipher.encrypt(msg.encode())
    print("="*50)
    print("Ciphertext: {}".format(ct))
    cipher.init_with_nonce(nonce)
    dec = cipher.decrypt(ct)
    print("="*50)
    print("Decrypted: {}".format(dec))
    # Compare explicitly: an `assert` would be stripped under `python -O`,
    # and a bare `except` around it would also swallow KeyboardInterrupt.
    if msg.encode() != dec:
        print("Decryption failed")

Essentially it generates a very scrambled S-box using the key and nonce, and then does substitution to the plaintext. But to add more confusion we update the S-box for every character that we substitute.

So how do we approach this problem? We can start by analyzing some possible simple attacks first:

Bruteforce?

As noted by the comments as well, it seems very unlikely to bruteforce the whole 40-bit key in reasonable time. However, do take a look at how the key is actually used:

# LFSR17 seed: low 16 key bits, with bit 16 forced to 1 so the seed is never zero
self.lfsr17 = LFSR((self.key & 0xFFFF) | (1 << 16), [2, 9, 10, 11, 14, 16], 17)
# LFSR32 seed: the remaining key bits (key >> 16) with the constant 0xAB OR-ed into the top byte
self.lfsr32 = LFSR(((self.key >> 16) | 0xAB << 24) & 0xFFFFFFFF, [1, 6, 16, 21, 23, 24, 25, 26, 30, 31], 32)

So it breaks the key into 2 parts, the last 16 bits are used for LFSR17, and the other 24 bits are used for LFSR32.

This means that although we can’t bruteforce the entire key, we might be able to just bruteforce the LFSR17 which is 16 bits of the key in case we need it. This observation alone will not lead us to the solution but it will be used later.

The other thing that we might need to note is that everything here operates using these LFSRs, and an LFSR can be both shifted and unshifted. This means that if we recover any state of an LFSR, we can unshift it back to its initial state, which contains the key bits.

Recovering LFSR State

So the idea is getting clearer: we want to extract the state of the LFSR, preferably LFSR32, since as we discussed previously we can just bruteforce LFSR17. We can analyze this very confusing S-box generating function and we can probably find some crazy relation with the LFSR state maybe?

There is just one problem: IT IS WAY TOO SCRAMBLY

Let’s just cross the init_with_nonce out and suppose it’s secure, there is only one place left where the LFSR is actually used, the _update() function.

_update()

We know how the encryption works, it substitutes the plaintext using the generated S-box and then updates the S-box by swapping the value for the current character with another number produced by the LFSR.

def _update(self, v):
    # keystream byte: XOR of the next output byte of each LFSR
    i = self.lfsr17.next_byte() ^ self.lfsr32.next_byte()
    # swap the S-box entry that was just used (index v) with entry i
    self.S[v], self.S[i] = self.S[i], self.S[v]
    # keep the inverse table in sync for the two values that moved
    self.S_inv[self.S[v]] = v
    self.S_inv[self.S[i]] = i

def encrypt(self, msg):
    assert self.initialized
    ct = bytes()
    for v in msg:
        # substitute through the current S-box, then mutate the S-box before the next byte
        ct += self.S[v].to_bytes()
        self._update(v)
    return ct

This is interesting because it makes it such that the same consecutive characters will have different values. You can try it yourself, it will give different values for each consecutive character.

Image of plaintext "AAAA" mapping to ciphertext with 4 different bytes

But do take a look at the code again, what if v is equal to i?

def _update(self, v):
    i = self.lfsr17.next_byte() ^ self.lfsr32.next_byte()
    # when i == v this exchanges an entry with itself, so S is left unchanged
    self.S[v], self.S[i] = self.S[i], self.S[v]
    self.S_inv[self.S[v]] = v
    self.S_inv[self.S[i]] = i

If v == i, then the S-box will not change, thus the consecutive characters WILL HAVE THE SAME VALUE!

Let’s try to check the i value and give inputs that match it. I updated the script to print i, and also stopped printing the nonce to keep the debug output readable:

def _update(self, v):
    i = self.lfsr17.next_byte() ^ self.lfsr32.next_byte()
    # debug only: show the swap index chosen for this byte
    print(f"{hex(i) = }")
    self.S[v], self.S[i] = self.S[i], self.S[v]
    self.S_inv[self.S[v]] = v
    self.S_inv[self.S[i]] = i
Image of output corresponding to _update

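So we are right: we can now extract i by sending a message made of one repeated byte value and looking for two consecutive ciphertext bytes that are equal, which happens exactly when i equals that byte value. This works for any index too: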

Image of output for any index

Putting it all together

We know i is just the XOR of the output bytes of the 2 LFSRs, and we can bruteforce LFSR17, so in order to get the full state of LFSR32 we need 4 consecutive i values (each i is 1 byte and LFSR32 is 4 bytes). Because we can only send 100 plaintexts, we can only test 100 of the 256 possible byte values, so this is not guaranteed; but the length limit is generous (512 bytes), so we have a very decent chance that at least 4 consecutive i values among the 512 are below 100. As a rough estimate, a given window of 4 works with probability (100/256)^4 ≈ 2.3%, and there are about 500 windows.

After recovering 4 consecutive i values, we can bruteforce all LFSR17, get the state of LFSR32, and unshift them enough to get the original key.

It needs a little bit of calculation, but basically the init_with_nonce function does 511 next_byte calls per LFSR, so that is 4088 shifts; we also need to add the starting offset within the 512 bytes that we encrypt. Doing 4088 shifts for each of the 2^16 LFSR17 candidates is a lot of work, so I recommend doing this precalculation before connecting to the remote server.

Full solver:

from wrth import *
from tqdm import tqdm
class LFSR:
    """Left-shifting linear feedback shift register that can also step backwards.

    A forward step moves the state one bit to the left and inserts the XOR
    of the tapped bits as the new lowest bit; ``_unshift`` undoes one step.
    """

    def __init__(self, seed, taps, size):
        """Store the register parameters.

        Args:
            seed: Initial state; must be non-zero and fit in ``size`` bits.
            taps: Non-empty collection of bit positions; must contain
                ``size - 1``.
            size: Register width in bits.
        """
        assert seed != 0
        assert (seed >> size) == 0
        assert len(taps) > 0 and (size - 1) in taps
        self.state = seed
        self.taps = taps
        self.mask = (1 << size) - 1
        self.size = size

    def _tap_parity(self, value):
        """Return the XOR of the bits of ``value`` at the tap positions."""
        return sum((value >> position) & 1 for position in self.taps) & 1

    def _shift(self):
        """Advance the register by one bit."""
        self.state = ((self.state << 1) | self._tap_parity(self.state)) & self.mask

    def _unshift(self):
        """Step the register back by one bit.

        The lowest bit is the feedback produced by the previous state.  After
        shifting right the top bit is still zero, so XOR-ing the feedback with
        the tap parity of the shifted value yields the bit that was pushed out.
        """
        inserted = self.state & 1
        older = self.state >> 1
        dropped = inserted ^ self._tap_parity(older)
        self.state = older | (dropped << (self.size - 1))

    def next_byte(self):
        """Return the current low 8 state bits, then advance by 8 steps."""
        output = self.state & 0xFF
        steps_left = 8
        while steps_left:
            self._shift()
            steps_left -= 1
        return output

# Pre-advance every possible LFSR17 seed by 511 * 8 = 4088 shifts (the amount
# used here for the state reached before the first encrypted byte), so this
# work is done before the connection is opened.
print("precalc")
lfsr17precalc = []
for low16 in tqdm(range(1 << 16)):
    warmed = LFSR(low16 | (1 << 16), [2, 9, 10, 11, 14, 16], 17)
    for _ in range(511 * 8):
        warmed._shift()
    lfsr17precalc.append(warmed)

r = con("nc 0.cloud.chals.io 23753")
def recover_i():
    """Query the oracle and recover as many swap indices as possible.

    For each byte value 0..99 a message of 512 copies of that value is
    encrypted.  Two equal adjacent ciphertext bytes at positions j and j+1
    are taken to mean that the swap index for position j equals the byte
    value that was sent, so every such position is recorded.

    Returns:
        A tuple ``(starting, recovered, res)`` where ``res`` holds the
        recovered index per position (-1 if unknown), ``starting`` is the
        first position with four consecutive known indices and ``recovered``
        is that run of four.  Exits the process if no such run exists.
    """
    import ast

    res = [-1] * 512
    print("generating ciphertext...")
    for i in tqdm(range(100)):
        test = bytes([i] * 512)
        r.sendlineafter(b"plaintext: ", str(test))
        r.recvuntil(b"nonce: ")
        r.recvline()  # the nonce itself is not needed
        r.recvuntil(b"ciphertext: ")
        line = r.recvline()
        if isinstance(line, bytes):
            line = line.decode()
        # The line comes from the remote service: parse it as a literal
        # rather than running it through eval().
        ct = ast.literal_eval(line.strip())
        # Record every match, not just the first; one byte value can be the
        # swap index at several positions.
        for j in range(len(ct) - 1):
            if ct[j] == ct[j + 1]:
                res[j] = i

    starting = -1
    for i in range(len(res) - 3):
        if all(value != -1 for value in res[i:i + 4]):
            starting = i
            break

    if starting == -1:
        print("run again")
        exit()

    print(f"{starting = }")
    recovered = res[starting:starting+4]
    print(f"{recovered = }")

    return starting, recovered, res

starting, recovered, res = recover_i()

# Number of shifts each LFSR has performed before producing the keystream
# byte at position `starting`: 4088 before the first encrypted byte plus
# 8 per encrypted byte.
shiftamount = 4088 + (starting)*8
for keylsb in tqdm(range(2**16)):
    # Advance this candidate's pre-warmed LFSR17 to position `starting`.
    lfsr17 = lfsr17precalc[keylsb]
    for _ in range(starting*8):
        lfsr17._shift()
    # Each known swap index is lfsr17_byte ^ lfsr32_byte, so XOR-ing out the
    # LFSR17 bytes leaves four consecutive LFSR32 output bytes.  These are
    # the LFSR32 state at the moment the fourth byte is emitted (first
    # emitted byte in the top position).
    lfsr32bytes = [lfsr17.next_byte() ^ recovered[i] for i in range(4)]
    lfsr32state = int.from_bytes(bytes(lfsr32bytes), "big")
    if lfsr32state == 0:
        # The LFSR constructor rejects a zero seed, so skip this candidate
        # instead of aborting the whole search.
        continue
    lfsr32 = LFSR(lfsr32state, [1, 6, 16, 21, 23, 24, 25, 26, 30, 31], 32)
    # Rewind both registers to position `starting`.
    for _ in range(4*8):
        lfsr17._unshift()
    for _ in range(3*8):
        lfsr32._unshift()
    # Check the candidate against every recovered swap index.
    good = True
    for i in range(starting, len(res)):
        predicted = lfsr17.next_byte() ^ lfsr32.next_byte()
        if res[i] != -1 and predicted != res[i]:
            good = False
            break
    if not good:
        continue

    # Undo the verification pass, then rewind to the initial seed.
    for _ in range((len(res) - starting) * 8 + shiftamount):
        lfsr32._unshift()
    if (lfsr32.state >> 24) != 0xAB:
        # The cipher ORs 0xAB into the top byte of the LFSR32 seed.
        continue

    keymsb = lfsr32.state & (2**24 - 1)
    print(keymsb)
    print(keylsb)
    key = (keymsb << 16) + keylsb
    print(key)
    r.sendlineafter(b"Key: ",str(key))
    r.interactive()
    # The key has been submitted; stop searching.
    break
else:
    print("no key candidate matched, run again")

Flag: uoftctf{wH0_w0u1D_h4ve_7houGHt_l0ng_nONceS_CAnt_S4ve_w3ak_KeYS}