#!/usr/bin/python

import random
import re

# Approximate chance, in percent, that transform() emits a junk
# instruction in front of a line.
NOP_PERCENT = 50

# Registers the junk generators may pick from.
registers = ["eax", "ebx", "ecx", "edx", "esi", "edi"]
# Subset of `registers` not mentioned in the input; filled by
# get_unused_registers().
unused_registers = []
# Lines of the input assembly file; filled by main().
asm_lines = []
# Constants with no zero byte, used to build null-free immediates.
x01010101 = 0x01010101  # == 16843009
x02020202 = 0x02020202  # == 33686018

def get_unused_registers():
   """Set unused_registers to the entries of `registers` the input never uses.

   Detection is textual and errs on the side of "used".  A register counts
   as used when its 32-bit name occurs anywhere in a line of asm_lines, even
   inside a comment or a label.  It also counts as used when one of its
   16/8-bit sub-registers (ax, ah, al, si, ...) occurs as a whole word.

   A false "used" only means fewer junk variants.  A false "unused" would
   let get_nop(), get_pop() or get_mov() overwrite a live register.

   The list is rebuilt in place, so calling this again does not append
   duplicates.
   """
   # Sub-registers aliasing each 32-bit register: writing one of them
   # changes the 32-bit register, so its use makes the register "used".
   sub_registers = {
      "eax": ("ax", "ah", "al"),
      "ebx": ("bx", "bh", "bl"),
      "ecx": ("cx", "ch", "cl"),
      "edx": ("dx", "dh", "dl"),
      "esi": ("si",),
      "edi": ("di",),
   }

   def is_used(reg):
      # The 32-bit name is matched as a plain substring (as before); the
      # short aliases need word boundaries so "call" does not match "al".
      pattern = re.escape(reg)
      aliases = sub_registers.get(reg)
      if aliases:
         pattern += r"|\b(?:" + "|".join(aliases) + r")\b"
      regex = re.compile(pattern)
      return any(regex.search(line) for line in asm_lines)

   # Build a fresh list instead of calling remove() on the list being
   # iterated.  remove() during iteration skipped the element following each
   # removed one, so a used register could be left in the "unused" list.
   unused_registers[:] = [reg for reg in registers if not is_used(reg)]


def get_nop():
   """Return a junk instruction sequence that preserves every used register.

   One variant is chosen uniformly at random.  Variants 4 and 5 are only
   possible when unused_registers is non-empty.
     0. nop
     1. push reg / pop reg
     2. inc reg / dec reg
     3. push reg1 / mov reg1,reg2 / push reg1 / pop reg2 / pop reg1
        (reg1 != reg2; both registers end up with their original values)
     4. inc unused_reg
     5. mov unused_reg, <random 32-bit value with no zero byte>
   Multi-instruction sequences are joined with "\\n".

   NOTE: variants 2 and 4 modify EFLAGS; variants 4 and 5 overwrite a
   register taken from unused_registers.
   """
   nb_choice = 3
   if unused_registers:
      nb_choice += 2
   choice = random.randint(0, nb_choice)

   if choice == 0:  # NOP
      return "nop"

   if choice == 1:  # push reg, pop reg
      reg = random.choice(registers)
      return "push " + reg + "\npop " + reg

   if choice == 2:  # inc reg, dec reg
      reg = random.choice(registers)
      return "inc " + reg + "\ndec " + reg

   if choice == 3:  # push reg1, mov reg1 <- reg2, push reg1, pop reg2, pop reg1
      # Two distinct registers picked from the list itself, so the range
      # follows `registers` instead of a hard-coded 0..5 index range.
      reg1, reg2 = random.sample(registers, 2)
      return ("push " + reg1 + "\nmov " + reg1 + "," + reg2 + "\npush " + reg1
              + "\npop " + reg2 + "\npop " + reg1)

   scratch = random.choice(unused_registers)
   if choice == 4:  # inc unused_reg
      return "inc " + scratch

   # choice == 5: mov unused_reg, junk nb -- every byte in 1..255 (no nulls)
   junk = ((random.randint(1, 255) << 24) + (random.randint(1, 255) << 16)
           + (random.randint(1, 255) << 8) + random.randint(1, 255))
   return "mov " + scratch + "," + str(junk)
        

def get_mov(val1,val2):
   """Return an assembly snippet equivalent to ``mov val1, val2``.

   The rewrite applies with probability 1/2, and only when val1 is one of
   al/bl/cl/dl and val2 is a numeric literal in 0..255.  In that case the
   move is replaced by an equivalent instruction sequence.  In every other
   case the plain ``mov val1, val2`` is returned.

   Numeric literals understood:
     - decimal (``11``)
     - ``0x`` hex (``0x0b``)
     - ``h``-suffixed hex that starts with a digit (``0bh``)
   Anything else, including registers such as ``ah``, is left untouched.

   NOTE: unlike mov, the replacement sequences modify EFLAGS, and the
   longest one overwrites a register taken from unused_registers.
   """
   if random.randint(0,1) == 1 and val1 in ("al", "bl", "cl", "dl"):
      # Parse val2 strictly.  Registers like "ah" must not be mistaken for
      # h-suffixed hex, and "0x" literals need base 16 (plain int() raised
      # ValueError on them).
      num = None
      if re.match(r"0x[0-9a-fA-F]+\Z", val2):
         num = int(val2, 16)
      elif re.match(r"[0-9][0-9a-fA-F]*h\Z", val2):
         num = int(val2[:-1], 16)
      elif re.match(r"[0-9]+\Z", val2):
         num = int(val2)

      if num is not None and 0 <= num <= 255:
         wide = "e" + val1[0] + "x"  # 32-bit register containing val1
         if num == 0:
            # The generic "mov al, -2 / inc eax / inc eax" would carry into ah.
            return "mov " + val1 + ", 1\ndec " + val1
         if num == 1:
            return "mov " + val1 + ", 2\ndec " + val1
         if num == 2:
            return "mov " + val1 + ", 1h\ninc " + val1
         if len(unused_registers) == 0:
            # Low byte = num - 2 (>= 1), then +2 on the full register; no
            # carry into the upper bytes because num <= 255.
            return "mov " + val1 + ", " + str(num - 2) + "\ninc " + wide + "\ninc " + wide
         reg = unused_registers[random.randint(0,len(unused_registers)-1)]
         # How the sequence works (no zero byte in any immediate since num >= 3):
         #   - copy the full register into reg
         #   - clear reg's low byte (0xffffff01 & 0xffffff02)
         #   - add 0x01010101 + num - 2 and increment
         #   - copy reg back into the full register
         #   - subtract 0x01010101 and increment again
         # Result: the upper 24 bits are preserved and the low byte is num.
         return ("push " + wide + "\npop " + reg + "\nand " + reg + ", 0xffffff01\nand " + reg
                 + ", 0xffffff02\nadd " + reg + ", " + str(x01010101 + num - 2) + "\ninc " + reg
                 + "\npush " + reg + "\npop " + wide + "\nsub " + wide + ", " + str(x01010101)
                 + "\ninc " + wide)

   return "mov " + val1 + ", " + val2


def get_xor(val1,val2):
   """Return an assembly snippet equivalent to ``xor val1, val2``.

   A coin is flipped first.  On heads, and only when both operands are the
   same (the register-zeroing idiom), a second coin flip chooses between:
     - ``and r, 0x01010101`` then ``and r, 0x02020202`` (the two masks share
       no bit, so r becomes 0; the immediates are written in decimal), or
     - ``sub r, r``.
   Every other case returns the plain xor.
   """
   rewrite = random.randint(0,1) == 1
   if not (rewrite and val1 == val2):
      return "xor " + val1 + ", " + val2

   if random.randint(0,1) != 1:
      return "sub " + val1 + ", " + val1

   first_mask = "and " + val1 + ", " + str(x01010101)
   second_mask = "and " + val2 + ", " + str(x02020202)
   return first_mask + "\n" + second_mask

def get_pop(reg):
   """Return an assembly snippet equivalent to ``pop reg``.

   With probability 1/2 the pop is replaced by an equivalent sequence:
     - with a free register ureg (taken from unused_registers, never reg
       itself): ``pop ureg / xor reg,reg / add reg,ureg``.  This overwrites
       ureg and EFLAGS.
     - otherwise: ``mov reg,[esp] / add esp,0x01010105 / sub esp,0x01010101``,
       i.e. esp += 4 without a zero byte in the immediates.  This
       overwrites EFLAGS.

   Only 32-bit general-purpose registers other than esp are rewritten.
   For esp the emulation is wrong.  For 16-bit or segment registers the
   emitted code either does not assemble or moves esp by the wrong amount.
   Such registers are returned as a plain pop.
   """
   if reg not in ("eax", "ebx", "ecx", "edx", "esi", "edi", "ebp"):
      return "pop " + reg

   if random.randint(0,1) == 1:
      # Using reg itself as scratch would give "pop eax / xor eax, eax /
      # add eax, eax", i.e. 0 instead of the popped value.
      scratch = [r for r in unused_registers if r != reg]
      if not scratch:
         return "mov " + reg + ", [esp]\nadd esp,0x01010105\nsub esp, " + str(x01010101)
      ureg = scratch[random.randint(0, len(scratch) - 1)]
      return "pop " + ureg + "\nxor " + reg + ", " + reg + "\nadd " + reg + ", " + ureg

   return "pop " + reg

def transform():
   """Print asm_lines to stdout, obfuscating the code between "BITS 32" and "sc:".

   Lines up to and including the "BITS 32" line, and from the "sc:" label
   on, are echoed unchanged.

   Every instruction line in between that does not read EFLAGS gets two
   treatments:
     - with probability NOP_PERCENT percent, it is preceded by a junk
       sequence from get_nop();
     - if it is a plain mov/xor/pop whose operands are registers or numeric
       literals, it is rewritten by get_mov(), get_xor() or get_pop().

   Blank, comment-only and label-only lines never receive junk.  This
   prevents junk from landing between a flag-setting instruction and the
   flag reader that follows it.

   NOTE(review): the rewritten sequences and several junk variants modify
   EFLAGS themselves.  A rewritten line (or junk in front of it) sitting
   between a flag setter and its consumer can still change behavior.
   Fixing that would need flag-liveness analysis.
   """
   # Operand grammar accepted by the rewrites.  Each pattern must match the
   # whole line, up to an optional ";" comment.  A partial match would drop
   # the rest of the operand ("mov eax, 0x1f" -> "0x1", "mov ecx, string"
   # -> "s") or a leading "label:".
   reg = r"(?:e?[abcd]x|[abcd][lh]|e?(?:si|di|sp|bp))"
   reg32 = r"e(?:[abcd]x|si|di|bp)"  # 32-bit GPRs except esp
   num = r"(?:0x[0-9a-fA-F]+|[0-9][0-9a-fA-F]*h|[0-9]+)"
   tail = r"\s*(?:;.*)?$"
   mov = re.compile(r"^\s*mov\s+(" + reg + r")\s*,\s*(" + reg + "|" + num + ")" + tail)
   xor = re.compile(r"^\s*xor\s+(" + reg + r")\s*,\s*(" + reg + "|" + num + ")" + tail)
   pop = re.compile(r"^\s*pop\s+(" + reg32 + ")" + tail)

   # First token after an optional "label:"; no match on blank,
   # comment-only and label-only lines.
   mnemonic = re.compile(r"^\s*(?:[A-Za-z_.][\w.]*:)?\s*([A-Za-z]+)")
   # Mnemonics that read EFLAGS: conditional jumps (not jmp), loope/loopne,
   # adc, sbb, cmovcc, setcc, pushf, lahf, rcl/rcr, into.
   flag_reader = re.compile(
      r"(?:j(?!mp\Z)[a-z]+|loopn?[ez]|adc|sbb|cmov[a-z]+|set[a-z]+|pushf[dw]?|lahf|rc[lr]|into)\Z")

   nomore = True
   for line in asm_lines:
      line = line.strip("\n")
      marker = line.strip()  # also tolerates tabs and "\r" around markers

      if marker == "sc:":  # Ensure no changes are made at the end
         nomore = True

      if not nomore:
         op = mnemonic.match(line)
         if op is not None and flag_reader.match(op.group(1).lower()) is None:
            # randint(1, 100) gives exactly NOP_PERCENT% (0 disables junk).
            if random.randint(1, 100) <= NOP_PERCENT:  # Do we add a junk instruction ?
               print(get_nop())

            s = mov.match(line)
            if s is not None:
               line = get_mov(s.group(1), s.group(2))
            else:
               s = xor.match(line)
               if s is not None:
                  line = get_xor(s.group(1), s.group(2))
               else:
                  s = pop.match(line)
                  if s is not None:
                     line = get_pop(s.group(1))

      print(line)  # And finally print the line
      if marker == "BITS 32":
         nomore = False


def main():
   """Command-line entry point: obfuscate the assembly file given as argv[1].

   Prints a usage message to stderr and exits with status 1 unless exactly
   one argument is given.  Otherwise it:
     - loads the file into asm_lines,
     - computes unused_registers,
     - writes the transformed listing to stdout.
   """
   import sys
   global asm_lines

   if len(sys.argv) != 2:
      # sys.stderr.write works on both Python 2 and 3, unlike "print >>".
      sys.stderr.write("Usage: %s <path_to_asm_file>\n" % (sys.argv[0]))
      sys.exit(1)

   # "with" closes the file even if reading it fails.
   with open(sys.argv[1], "r") as asm_file:
      asm_lines = asm_file.readlines()

   get_unused_registers()  # Get (approximately) which registers are in use and which ones aren't
   random.seed()  # No argument: seed from OS randomness, or the current time if unavailable
   transform()  # Transform the instructions without changing the semantics

# Run main() only when executed as a script, not when imported as a module.
if __name__ == "__main__":
   main()
