When running the problem, we're presented with a map of the world and an input prompt. The binary takes our input and checks it, and then prints out that we entered the incorrect flag.
I opened the binary in Binary Ninja and saw a function called Main_validateFlag_info that looked like the below image, which immediately made me want to see if I could side-channel an answer instead of trying to reverse Haskell.
My normal side-channel tools utilize qemu usermode emulation, but I noticed that the number of instructions executed for the same input varied. Below I'm sending in just the letter "A" and seeing a different number of instructions executed every time. This told me that I might be able to use an instruction counting side-channel but it might be unstable.
I wanted to know the flag length. Usually I can just keep adding "A"s until the instruction count starts climbing, but here too the increases were inconsistent from run to run.
Knowing how unstable instruction counting was looking, I wanted to find an alternate side-channel that would give me a clearer picture of how many successful instructions or calls were executed.
Here is where the RAX==0 side-channel comes in.
Checking for RAX == 0
RAX is the return or result register for amd64. Most return status codes use 0 to indicate success. The thinking here is that if we count the number of times RAX is equal to 0, we get a side-channel indicating general program success without worrying about instruction counts or threads/timers messing up our measurement!
A little while ago I started writing a tracing utility that hooks up to another process using ptrace, hunts for call instructions, and dumps their arguments. I decided to extend this tracer with register comparisons. The premise of the utility is twofold:
fork process from tracer and set PTRACE_TRACEME, then execve the program
PTRACE_SINGLESTEP the process, PTRACE_GETREGS then count and compare regs.RAX == 0
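The two steps above can be sketched in Python with ctypes. This is not the call_trace tool used in this write-up (its source isn't shown here), only a minimal illustration that assumes Linux on x86-64 and a single-threaded tracee. The function name count_zero_rax and the UserRegs class are hypothetical; the request numbers and register layout are the ones from sys/ptrace.h and sys/user.h.

```python
import ctypes
import os

# ptrace request numbers on Linux (from <sys/ptrace.h>).
PTRACE_TRACEME = 0
PTRACE_SINGLESTEP = 9
PTRACE_GETREGS = 12


class UserRegs(ctypes.Structure):
    """Mirror of struct user_regs_struct for x86-64 Linux (<sys/user.h>)."""
    _fields_ = [(name, ctypes.c_ulonglong) for name in (
        "r15", "r14", "r13", "r12", "rbp", "rbx", "r11", "r10", "r9", "r8",
        "rax", "rcx", "rdx", "rsi", "rdi", "orig_rax", "rip", "cs", "eflags",
        "rsp", "ss", "fs_base", "gs_base", "ds", "es", "fs", "gs")]


def count_zero_rax(argv):
    """Run argv under ptrace and count single-steps where RAX == 0.

    Hypothetical helper: assumes Linux x86-64 and ignores extra threads.
    """
    libc = ctypes.CDLL(None, use_errno=True)
    libc.ptrace.argtypes = [ctypes.c_long, ctypes.c_long,
                            ctypes.c_void_p, ctypes.c_void_p]
    libc.ptrace.restype = ctypes.c_long

    pid = os.fork()
    if pid == 0:
        # Child: ask to be traced, then replace ourselves with the target.
        try:
            libc.ptrace(PTRACE_TRACEME, 0, None, None)
            os.execv(argv[0], argv)
        finally:
            os._exit(127)  # only reached if execv failed

    # Parent: the child stops with SIGTRAP once execv succeeds.
    os.waitpid(pid, 0)
    regs = UserRegs()
    zero_count = 0
    while True:
        if libc.ptrace(PTRACE_SINGLESTEP, pid, None, None) != 0:
            break
        _, status = os.waitpid(pid, 0)
        if os.WIFEXITED(status) or os.WIFSIGNALED(status):
            break
        libc.ptrace(PTRACE_GETREGS, pid, None, ctypes.byref(regs))
        if regs.rax == 0:
            zero_count += 1
    return zero_count
```

Calling count_zero_rax(["./l3"]) would trace the challenge binary reading from the current stdin. Single-stepping from Python is slow, so a compiled tracer is the practical choice for the brute-force loop later on.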
Once the tracer was printing out the RAX comparison values, I could feed it different numbers of "A"s to see what each input length resulted in. I used the python2 -c trick to print as many bytes as the bash loop variable i, then grepped the tracer output for the RAX line.
$ for i in $(seq 1 30);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX;done
1
46813 : times RAX == 0
2
46823 : times RAX == 0
3
46819 : times RAX == 0
4
46823 : times RAX == 0
5
46812 : times RAX == 0
6
46809 : times RAX == 0
7
46806 : times RAX == 0
8
46805 : times RAX == 0
9
46808 : times RAX == 0
10
46811 : times RAX == 0
11
46808 : times RAX == 0
12
46809 : times RAX == 0
13
46803 : times RAX == 0
14
46819 : times RAX == 0
15
46820 : times RAX == 0
16
46807 : times RAX == 0
17
46830 : times RAX == 0
18
46829 : times RAX == 0
19
46815 : times RAX == 0
20
46813 : times RAX == 0
21
46814 : times RAX == 0
22
47083 : times RAX == 0 #<--------------------- HERE is the increase
23
46815 : times RAX == 0
24
46813 : times RAX == 0
25
46812 : times RAX == 0
26
46813 : times RAX == 0
27
46814 : times RAX == 0
28
46814 : times RAX == 0
29
46810 : times RAX == 0
30
46816 : times RAX == 0
I actually learned at this point that almost EVERY register provided a side-channel and showed a large increase in registers equal to 0:
$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RCX;done
19
39604 : times RCX == 0
20
39621 : times RCX == 0
21
39633 : times RCX == 0
22
40172 : times RCX == 0 #<--- here's the jump
23
39662 : times RCX == 0
24
39677 : times RCX == 0
25
39693 : times RCX == 0
$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RDI;done
19
29505 : times RDI == 0
20
29524 : times RDI == 0
21
29530 : times RDI == 0
22
29661 : times RDI == 0 #<--- here's the jump
23
29567 : times RDI == 0
24
29578 : times RDI == 0
25
29596 : times RDI == 0
$ for i in $(seq 19 25);do echo $i; python2 -c "print 'A'*$i" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RSI;done
19
70220 : times RSI == 0
20
70215 : times RSI == 0
21
70217 : times RSI == 0
22
70665 : times RSI == 0 #<--- here's the jump
23
70216 : times RSI == 0
24
70221 : times RSI == 0
25
70220 : times RSI == 0
At this point we have multiple registers showing a jump in their == 0 counts at input length 22, so I knew the flag must be 22 characters long.
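Spotting the jump by eye works for 30 lines, but it can also be automated. The sketch below uses the RAX counts from the first run above and picks the length that sits furthest above the median. The function name find_outlier_length is hypothetical, and the median is just one reasonable baseline given the run-to-run noise.

```python
from statistics import median

# RAX == 0 counts per input length, copied from the tracer output above.
rax_zero_counts = {
    1: 46813, 2: 46823, 3: 46819, 4: 46823, 5: 46812, 6: 46809,
    7: 46806, 8: 46805, 9: 46808, 10: 46811, 11: 46808, 12: 46809,
    13: 46803, 14: 46819, 15: 46820, 16: 46807, 17: 46830, 18: 46829,
    19: 46815, 20: 46813, 21: 46814, 22: 47083, 23: 46815, 24: 46813,
    25: 46812, 26: 46813, 27: 46814, 28: 46814, 29: 46810, 30: 46816,
}


def find_outlier_length(counts):
    """Return (length, excess) for the count furthest above the median."""
    baseline = median(counts.values())
    length = max(counts, key=lambda n: counts[n] - baseline)
    return length, counts[length] - baseline


length, excess = find_outlier_length(rax_zero_counts)
print(f"likely flag length: {length} (+{excess} over the median)")
```

Every other length stays within a few dozen of the baseline, so length 22 stands out clearly even with the noise.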
How to build a side-channel solve
The next step is to find out how the problem is checking those input values. Usually these problems compare each value left to right against the hard-coded flag value, but sometimes they'll check certain pieces of the input first.
I needed to see if it started by checking the first character against the flag. The organizers shared that the flag format was shctf{}, so we know the first character MUST be s. We can run our tracer again, once with all "A"s and once with an "s" at the beginning followed by "A"s, then extend the known prefix one character at a time:
$ python2 -c "print 'A' + 'A'*21" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
47081 : times RAX == 0 # <----------- Known bad input for first character
$ python2 -c "print 's' + 'A'*21" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
47835 : times RAX == 0 # <----------- Known good input for first character
$ python2 -c "print 'sh' + 'A'*20" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
49596 : times RAX == 0 # <----------- Known good input for second character
$ python2 -c "print 'shc' + 'A'*19" | NO_CALL_TRACE=1 ./build/call_trace ./l3 | grep RAX
52953 : times RAX == 0 # <----------- Known good input for third character
At this point I can tell that we're getting our input checked byte by byte, and each successful character drastically increases the number of times RAX is equal to 0.
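To put numbers on "drastically", here is the arithmetic on the four counts above. The dictionary layout is only an illustration; the values are the ones the tracer printed.

```python
# RAX == 0 counts from the four runs above, keyed by the number of
# correct leading characters in the 22-character input.
counts = {0: 47081, 1: 47835, 2: 49596, 3: 52953}

# Gain contributed by each additional correct character.
deltas = {n: counts[n] - counts[n - 1] for n in range(1, len(counts))}

for n, gain in deltas.items():
    print(f"correct character #{n}: +{gain}")
```

Each correct character adds hundreds to thousands of extra RAX == 0 hits, while the noise between identical runs seen earlier was only a few dozen, so a single run per candidate should usually be enough to tell right from wrong.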
The next step toward writing a solver is to wrap these calls and check the tracer output. I always start my solve scripts with a check method that sends a given input into the tool and comes back with the expected output.
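The check method itself isn't reproduced in this section, so what follows is only a sketch of what such a wrapper could look like, written for Python 3. The names parse_rax_count, TRACER, and TARGET are hypothetical. The signatures of check and mod_input are inferred from how the solver loop below calls them, and storing results keyed by count is an assumption based on how get_max reads the dictionary.

```python
import os
import re
import subprocess

TRACER = "./build/call_trace"  # tracer path as used in the shell commands above
TARGET = "./l3"                # challenge binary

# Matches tracer lines such as "47083 : times RAX == 0".
COUNT_RE = re.compile(r"(\d+) : times RAX == 0")


def parse_rax_count(output):
    """Extract the RAX == 0 count from tracer output, or None if absent."""
    match = COUNT_RE.search(output)
    return int(match.group(1)) if match else None


def mod_input(flag, index, char):
    """Return a copy of flag with the character at index replaced by char."""
    return flag[:index] + char + flag[index + 1:]


def check(candidate, results):
    """Run the tracer on candidate and record its count in results.

    Assumption: results maps count -> candidate, so the largest key
    identifies the best guess. Two candidates with the same count
    overwrite each other.
    """
    env = dict(os.environ, NO_CALL_TRACE="1")
    proc = subprocess.run([TRACER, TARGET], input=candidate + "\n",
                          capture_output=True, text=True, env=env)
    count = parse_rax_count(proc.stdout)
    if count is not None:
        results[count] = candidate
```

Because check writes into a shared dictionary instead of returning a value, it can be used directly as a thread target.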
Once both of those steps are done, all that's left is to try every letter at each position and keep the candidate that produces the greatest RAX==0 count.
Since the majority of the processing time here is spent inside call_trace, I'm using Python threading to start a thread per character and join them all once we have every value.
def get_max(my_dict):
    # Get the biggest key from the dictionary
    big_num = max(k for k, v in my_dict.items())
    print(my_dict[big_num])
    return my_dict[big_num]

# Iterator over printable letters
iter_range = range(0x21, 0x7E)
iter_range = [chr(x) for x in iter_range]
results = {}
for i in range(len(curr_flag)):
    print(i)
    results = {}
    threads = []
    for x in iter_range:
        curr = mod_input(curr_flag, i, x)
        t = threading.Thread(target=check, args=[curr, results])
        t.start()
        threads.append(t)
    for thread in threads:
        thread.join()
    # Modify our current guess to have the best letter
    curr_flag = get_max(results)
The script will take a while, but here is what it looks like running: