Everything-is-wrong

Writing a debugger to side-channel out comparisons with RAX values

Problem overview

When running the problem, we're presented with a map of the world and an input prompt. The binary takes our input and checks it, and then prints out that we entered the incorrect flag.

I opened the binary in Binary Ninja and found a function called Main_validateFlag_info. One look at it immediately made me want to see if I could side-channel an answer instead of trying to reverse compiled Haskell.

My normal side-channel tools use qemu user-mode emulation, but I noticed that the number of instructions executed for the same input varied between runs. Below I'm sending in just the letter "A" and getting a different count every time. This told me that an instruction-counting side-channel might work, but that it would likely be unstable.

  echo "A" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38761
  echo "A" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l 
38651
  echo "A" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38664
  echo "A" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38756
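
To put a number on that instability, here is a small sketch that summarizes the four counts shown above. The helper name summarize is my own and not part of any tool mentioned here. Any per-character signal smaller than the spread it reports would be hard to separate from run-to-run noise.

```python
from statistics import mean

def summarize(counts):
    """Return (min, max, spread, mean) for a list of per-run counts."""
    lo, hi = min(counts), max(counts)
    return lo, hi, hi - lo, mean(counts)

# The four runs shown above, all with the same input "A".
runs = [38761, 38651, 38664, 38756]
lo, hi, spread, avg = summarize(runs)
print(f"min={lo} max={hi} spread={spread} mean={avg}")
```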

I wanted to know the flag length. Usually I can just add more "A"s and watch the instruction count climb, but again the count didn't increase consistently; it just fluctuated.

  spaceheroes2 echo "A" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38756
  spaceheroes2 echo "AA" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38726
  spaceheroes2 echo "AAA" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38821
  spaceheroes2 echo "AAAA" | qemu-x86_64 -d in_asm,nochain ./l3  2>&1 | wc -l
38804

Seeing how unstable instruction counting looked, I wanted an alternative side-channel that would give me a clearer picture of how many instructions or calls succeeded.

This is where the RAX == 0 side-channel comes in.

Checking for RAX == 0

RAX is the return-value register on amd64, and most return status codes use 0 to indicate success. The thinking here is that if we count the number of times RAX is equal to 0, we get a side-channel indicating general program success, without worrying about instruction counts or threads/timers adding noise!

I started writing a tracing utility a little while ago that attaches to another process using ptrace, hunts for call instructions, and dumps their arguments. I decided to extend this tracer with register comparisons. The premise of the utility is twofold:

  • fork process from tracer and set PTRACE_TRACEME, then execve the program

  • PTRACE_SINGLESTEP the process, PTRACE_GETREGS then count and compare regs.RAX == 0
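
The real tracer is written in C and gets its register values from ptrace. The sketch below only models the counting step from the second bullet in Python, using made-up register snapshots in place of PTRACE_GETREGS results. The function name count_zero_registers and the snapshot format are assumptions for illustration, not the tracer's actual code.

```python
REGS = ("rax", "rcx", "rdx", "rdi", "rsi")

def count_zero_registers(snapshots, regs=REGS):
    """Tally, per register, how many single-step snapshots had it equal to 0."""
    zero_counts = {r: 0 for r in regs}
    steps = 0
    for snap in snapshots:
        steps += 1  # one snapshot per single-stepped instruction
        for r in regs:
            if snap[r] == 0:
                zero_counts[r] += 1
    return steps, zero_counts

# Three made-up snapshots standing in for what PTRACE_GETREGS would return.
fake_trace = [
    {"rax": 0, "rcx": 5, "rdx": 0, "rdi": 1, "rsi": 0},
    {"rax": 1, "rcx": 0, "rdx": 0, "rdi": 1, "rsi": 0},
    {"rax": 0, "rcx": 7, "rdx": 3, "rdi": 0, "rsi": 0},
]
steps, zeros = count_zero_registers(fake_trace)
print(steps, zeros)
```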

The call tracer is available at https://github.com/ChrisTheCoolHut/call_trace. After compiling it, you should be able to run it with:

echo "A" | NO_CALL_TRACE=1 ./build/call_trace ./l3
   .... SNIP ....
.slaitnederc CTF yfireV
Incorrect, chthonite

258499 : ins executed
47488 : times RAX == 0
39929 : times RCX == 0
51831 : times RDX == 0
29301 : times RDI == 0
70833 : times RSI == 0
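
If you want all of the register counts at once instead of grepping for one, the summary lines are easy to parse. This is a minimal sketch that assumes the output format is exactly what is shown above; parse_zero_counts is a hypothetical helper, not something the tracer ships with.

```python
import re

# Matches summary lines such as "47488 : times RAX == 0".
LINE_RE = re.compile(r"^\s*(\d+)\s*:\s*times\s+(\w+)\s*==\s*0\s*$")

def parse_zero_counts(output):
    """Map register name -> number of times it was 0, from tracer output text."""
    counts = {}
    for line in output.splitlines():
        m = LINE_RE.match(line)
        if m:
            counts[m.group(2)] = int(m.group(1))
    return counts

sample = """258499 : ins executed
47488 : times RAX == 0
39929 : times RCX == 0
51831 : times RDX == 0
29301 : times RDI == 0
70833 : times RSI == 0
"""
counts = parse_zero_counts(sample)
print(counts["RAX"])
```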

Getting the flag length

Once the tracer was printing the RAX comparison counts, I could feed it inputs of different lengths and see how the counts changed. I use the python2 -c trick to print "A" repeated $i times, where i is a bash loop variable, loop over lengths 1 through 30, and grep for the RAX line:

$ for i in $(seq 1 30);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX;done
1
46813 : times RAX == 0
2
46823 : times RAX == 0
3
46819 : times RAX == 0
4
46823 : times RAX == 0
5
46812 : times RAX == 0
6
46809 : times RAX == 0
7
46806 : times RAX == 0
8
46805 : times RAX == 0
9
46808 : times RAX == 0
10
46811 : times RAX == 0
11
46808 : times RAX == 0
12
46809 : times RAX == 0
13
46803 : times RAX == 0
14
46819 : times RAX == 0
15
46820 : times RAX == 0
16
46807 : times RAX == 0
17
46830 : times RAX == 0
18
46829 : times RAX == 0
19
46815 : times RAX == 0
20
46813 : times RAX == 0
21
46814 : times RAX == 0
22
47083 : times RAX == 0 #<--------------------- HERE is the increase
23
46815 : times RAX == 0
24
46813 : times RAX == 0
25
46812 : times RAX == 0
26
46813 : times RAX == 0
27
46814 : times RAX == 0
28
46814 : times RAX == 0
29
46810 : times RAX == 0
30
46816 : times RAX == 0

At this point I learned that almost EVERY register provided the same side-channel, showing a large increase in its == 0 count at length 22:

$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RCX;done
19
39604 : times RCX == 0
20
39621 : times RCX == 0
21
39633 : times RCX == 0
22
40172 : times RCX == 0 #<--- here's the jump
23
39662 : times RCX == 0
24
39677 : times RCX == 0
25
39693 : times RCX == 0

$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RDI;done
19
29505 : times RDI == 0
20
29524 : times RDI == 0
21
29530 : times RDI == 0
22
29661 : times RDI == 0 #<--- here's the jump
23
29567 : times RDI == 0
24
29578 : times RDI == 0
25
29596 : times RDI == 0

$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RSI;done
19
70220 : times RSI == 0
20
70215 : times RSI == 0
21
70217 : times RSI == 0
22
70665 : times RSI == 0 #<--- here's the jump
23
70216 : times RSI == 0
24
70221 : times RSI == 0
25
70220 : times RSI == 0

At this point multiple registers show a jump in their == 0 counts at input length 22, so I knew the flag must be 22 characters long.
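
Spotting the jump by eye works fine here, but it can also be automated. The sketch below picks the length whose count sits farthest above the median, using the RAX counts for lengths 19 through 25 from the scan above. find_outlier_length is a hypothetical helper, and using the median as the baseline is my own choice rather than anything the tracer does.

```python
from statistics import median

def find_outlier_length(counts_by_length):
    """Return (length, deviation) for the count farthest above the median."""
    base = median(counts_by_length.values())
    best = max(counts_by_length, key=lambda n: counts_by_length[n] - base)
    return best, counts_by_length[best] - base

# RAX == 0 counts for input lengths 19..25, copied from the scan above.
rax_counts = {
    19: 46815, 20: 46813, 21: 46814, 22: 47083,
    23: 46815, 24: 46813, 25: 46812,
}
length, deviation = find_outlier_length(rax_counts)
print(length, deviation)
```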

How to build a side-channel solve

The next step is to find out how the problem is checking those input values. Usually these problems compare each value left to right against the hard-coded flag value, but sometimes they'll check certain pieces of the input first.

I needed to see whether it started by checking the first character against the flag. The organizers shared that the flag format was shctf{}, so the first character MUST be s. We can run the tracer once with all "A"s and again with an "s" followed by "A"s, then keep extending the known prefix:

$ python2 -c "print 'A' + 'A'*21" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
47081 : times RAX == 0 # <----------- Known bad input for first character
$ python2 -c "print 's' + 'A'*21" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
47835 : times RAX == 0 # <----------- Known good input for first character
$ python2 -c "print 'sh' + 'A'*20" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
49596 : times RAX == 0 # <----------- Known good input for second character
$ python2 -c "print 'shc' + 'A'*19" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
52953 : times RAX == 0 # <----------- Known good input for third character

At this point I can tell that we're getting our input checked byte by byte, and each successful character drastically increases the number of times RAX is equal to 0.
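
To make "drastically" concrete, here is the arithmetic on the four counts above. Each correct character adds hundreds to thousands of RAX == 0 hits, while the counts in the length scan only wandered by a few dozen between runs. The helper name gains is made up for this example.

```python
# (known-good prefix, RAX == 0 count) pairs copied from the runs above.
prefix_counts = [
    ("", 47081),
    ("s", 47835),
    ("sh", 49596),
    ("shc", 52953),
]

def gains(rows):
    """Return (prefix, increase over the previous prefix) for each row after the first."""
    return [
        (rows[i][0], rows[i][1] - rows[i - 1][1])
        for i in range(1, len(rows))
    ]

for prefix, gain in gains(prefix_counts):
    print(f"{prefix!r}: +{gain}")
```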

The next step toward writing a solver is to wrap these calls and check the tracer output. I always start my solve scripts with a check method that sends a given input into the tool and records the resulting count:

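import subprocess

# Shell pipeline: feed the guess to the tracer and keep only the RAX summary line
cmd = '''
echo '{}' | NO_CALL_TRACE=1 /home/chris/projects/pwn_trace/build/call_trace ./l3 | grep -iE 'RAX'
'''

def check(guess, results):
    """Run the tracer on guess and record results[rax_zero_count] = guess."""
    n_cmd = cmd.format(guess)
    try:
        output = subprocess.check_output(n_cmd, shell=True)
    except subprocess.CalledProcessError:
        # Non-zero exit (for example, grep matched nothing): skip this guess
        return
    if output:
        # Output looks like b"47488 : times RAX == 0"
        num = int(output.split(b":")[0])
        results[num] = guess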
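
One caveat with the echo '{}' pipeline: a guess containing a single quote breaks the shell quoting, and ' (0x27) is inside the printable range tried later, so that guess silently fails. A possible alternative, sketched below, is to skip the shell and write the guess straight to the child's stdin. run_with_stdin is a hypothetical helper and I have not run it against the tracer, so it is demonstrated with a stand-in child process.

```python
import os
import subprocess
import sys

def run_with_stdin(argv, data, extra_env=None, timeout=60):
    """Run argv with data on stdin (no shell involved) and return stdout as text."""
    env = dict(os.environ, **extra_env) if extra_env else None
    proc = subprocess.run(
        argv,
        input=data.encode(),
        stdout=subprocess.PIPE,
        env=env,
        timeout=timeout,
    )
    return proc.stdout.decode(errors="replace")

# Stand-in child: reports how many characters it received on stdin.
child = [sys.executable, "-c", "import sys; print(len(sys.stdin.read()))"]
print(run_with_stdin(child, "it's{}\n").strip())

# Assumed equivalent of the pipeline above (untested against the tracer):
# run_with_stdin(["./build/call_trace", "./l3"], guess + "\n", {"NO_CALL_TRACE": "1"})
```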

Next we need to be able to modify individual characters of a guess. The snippet below takes a guess, a position, and a character, and substitutes the character in at that position.

def mod_input(user_input, position, character):
    user_input = list(user_input)
    user_input[position] = character
    return "".join(user_input)

Once those two steps are done, all that's left is to try every character at each position and compare the RAX == 0 counts to see which guess produces the greatest count.

Since the majority of the processing time is spent inside call_trace, I'm using Python threading to start one thread per character and join them once every value is in.

def get_max(my_dict):
    # Get the biggest key from the dictionary
    big_num = max(k for k, v in my_dict.items())

    print(my_dict[big_num])
    return my_dict[big_num]

# Candidate characters: printable ASCII from '!' (0x21) through '~' (0x7E)
iter_range = [chr(x) for x in range(0x21, 0x7F)]

results = {}
for i in range(len(curr_flag)):
    print(i)
    results = {}
    threads = []
    for x in iter_range:
        curr = mod_input(curr_flag, i, x)
        t = threading.Thread(target=check, args=[curr, results])
        t.start()
        threads.append(t)
    for thread in threads:
        thread.join()
    # Modify our current guess to have the best letter
    curr_flag = get_max(results)
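
The same greedy loop can be tried out without the tracer by swapping in a fake oracle. The sketch below is a simulation under stated assumptions: SECRET is a made-up string, and fake_score stands in for the RAX == 0 count by returning the length of the matching prefix. It uses a bounded thread pool instead of one thread per character, and keeps scores in a list so that two guesses with the same count cannot overwrite each other.

```python
from concurrent.futures import ThreadPoolExecutor

SECRET = "toy{d3m0_fl4g}"  # made-up stand-in, not the challenge flag

def fake_score(guess, secret=SECRET):
    """Stand-in oracle: the score grows with the length of the matching prefix."""
    score = 0
    for g, s in zip(guess, secret):
        if g != s:
            break
        score += 1
    return score

def recover(length, score, alphabet, workers=8):
    """Greedy left-to-right recovery: keep the best-scoring character per position."""
    guess = ["A"] * length
    for pos in range(length):
        candidates = [
            "".join(guess[:pos] + [c] + guess[pos + 1:]) for c in alphabet
        ]
        # Score all candidates for this position with a bounded pool of threads.
        with ThreadPoolExecutor(max_workers=workers) as pool:
            scores = list(pool.map(score, candidates))
        best = max(range(len(candidates)), key=scores.__getitem__)
        guess = list(candidates[best])
    return "".join(guess)

alphabet = [chr(c) for c in range(0x21, 0x7F)]
recovered = recover(len(SECRET), fake_score, alphabet)
print(recovered)
```

Against the real binary, score would be a function that runs the tracer on the guess and returns the RAX == 0 count.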

The script will take a while, but here is what it looks like running:

$ python3 solve_l3.py
0
sAAAAAAAAAAAAAAAAAAAAA
1
shAAAAAAAAAAAAAAAAAAAA
   .... SNIP .....
20
shctf{gn0rw_ll4_s1_t1A
21
shctf{gn0rw_ll4_s1_t1}

Flag

shctf{gn0rw_ll4_s1_t1}

Instruction Counting

I talked with the challenge author and learned that the instruction counting side-channel was stable enough both to get the flag length and to compare input values. Here is his solve: https://github.com/FITSEC/spaceheroes_ctf_23/blob/main/RE/Everything-is-wrong/solve.py

Once I learned that, I tried running Instruction Stomp, and it was able to find the flag too!

https://github.com/ChrisTheCoolHut/Instruction-Stomp

$ python3 InstStomp.py -i 22 --stdin ~/ctf/spaceheroes2/l3

Solve script

import string
import subprocess
import threading

FLAG_LEN = 22
curr_flag = "A" * FLAG_LEN
cmd = '''
echo '{}' | NO_CALL_TRACE=1 /home/chris/projects/pwn_trace/build/call_trace ./l3 | grep -iE 'RAX'
'''

def check(guess, results):
    n_cmd = cmd.format(guess)
    output = None
    try:
        output = subprocess.check_output(n_cmd, shell=True)
    except subprocess.CalledProcessError:
        # Tracer or grep exited non-zero: skip this guess
        pass
    if output:
        num = output.split(b":")[0]
        num = int(num)
        results[num] = guess
    
def mod_input(user_input, position, character):
    user_input = list(user_input)
    user_input[position] = character
    return "".join(user_input)

def get_max(my_dict):
    big_num = max(k for k, v in my_dict.items())

    print(my_dict[big_num])
    return my_dict[big_num]

# Try every printable ASCII character; the flag contains digits, braces, and
# underscores, so letters alone (string.ascii_letters) would never find it
iter_range = [chr(x) for x in range(0x21, 0x7F)]

results = {}
for i in range(len(curr_flag)):
    print(i)
    results = {}
    threads = []
    for x in iter_range:
        curr = mod_input(curr_flag, i, x)
        t = threading.Thread(target=check, args=[curr, results])
        t.start()
        threads.append(t)
    for thread in threads:
        thread.join()
    curr_flag = get_max(results)
