#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Martin Pitt <martin@piware.de>
# Based on original code of Thomas Pani <thomas.pani/at/gmail.com>

import socket
import struct
import sys
import math

# address of the MAME server to talk to
HOST = '127.0.0.1'
PORT = 1979

# one frame packet: 1024 bytes vector RAM + 1 byte frame number + 1 byte ping
PACKET_SIZE = 1026

# bullet speed assumed when aiming (used by Ship.hit_direction)
BULLET_SPEED = 8
# squared distance below which objects are considered as 'near'
NEAR_TRESHOLD = 100 * 100

# number of invalid packets received in succession
invalid_packets = 0

def dist(x1, y1, x2, y2, width=1024, height=768):
    '''Return the (dx, dy) vector from (x2, y2) to (x1, y1), taking edge
    wrapping into account.

    The screen wraps around horizontally at width and vertically at height,
    so the shortest vector is returned: for integer input dx lies in
    [-width/2, width/2 - 1] and dy in [-height/2, height/2 - 1]; for float
    input the half-open ranges [-width/2, width/2) and [-height/2, height/2)
    apply.'''

    half_w = width // 2
    half_h = height // 2
    # Python's % always returns a non-negative result for a positive divisor,
    # so this maps any difference into the range above in constant time.
    dx = (x1 - x2 + half_w) % width - half_w
    dy = (y1 - y2 + half_h) % height - half_h
    return (dx, dy)

class Object:
    '''Abstract representation of a screen object.

    Tracks the last reported position (x, y), a smoothed velocity (vx, vy)
    estimated from successive move() calls, and a collision radius. x/y are
    None until the first move(), vx/vy until the second.'''

    def __init__(self):
        '''Initialize screen object with unknown position and speed.'''
        self.x = None
        self.y = None
        self.vx = None
        self.vy = None

        self.radius = 0     # collision radius; subclasses override it

    def move(self, x, y):
        '''Set new position and update speed.

        The speed is computed with a very simple IIR low-pass filter: the new
        (edge-wrapped) displacement is weighted 0.7, the previous estimate 0.3.

        TODO: adapt responsiveness with #previous values'''
        if self.x is not None:
            (dx, dy) = dist(x, y, self.x, self.y)

            if self.vx is not None:
                self.vx = self.vx * 0.3 + dx * 0.7
                self.vy = self.vy * 0.3 + dy * 0.7
            else:
                # first displacement seen: take it as is
                self.vx = dx
                self.vy = dy
        self.x = x
        self.y = y

    def next_pos(self):
        '''Estimate the next position based on the current position/speed.

        Returns the rounded position one move() step ahead, or the current
        position if the speed is not known yet.'''

        if self.vx is not None:
            return (int(self.x + self.vx + .5), int(self.y + self.vy + .5))
        else:
            return (self.x, self.y)

    def dist_radius(self, o):
        '''Calculate distance to another object.

        This returns a tuple (dx, dy) of the minimal distance, taking edge
        wrapping and radiuses into account: each component of the center
        distance is shrunk towards zero by the sum of both radiuses, but never
        past zero (overlapping extents mean no gap along that axis).'''

        (dx, dy) = dist(o.x, o.y, self.x, self.y)
        r = self.radius + o.radius

        if dx > 0:
            dx = max(dx - r, 0)
        else:
            dx = min(dx + r, 0)
        if dy > 0:
            dy = max(dy - r, 0)
        else:
            dy = min(dy + r, 0)
        return (dx, dy)

    def next_dist_to_point(self, x, y):
        '''Calculate estimated distance to (x, y) in the next frame.

        This returns a tuple (dx, dy) of the minimal distance, taking edge
        wrapping and current speed into account (but no radius).

        This is for center point prediction.'''

        (nx, ny) = self.next_pos()
        return dist(x, y, nx, ny)

    def speed(self):
        '''Return squared speed, or 0 if unknown.'''

        if self.vx is None:
            return 0
        return self.vx*self.vx + self.vy*self.vy

    def collides(self, o):
        '''Return collision time with another object.

        The time is measured in move() steps (the unit of vx/vy), rounded to
        the nearest integer; 0 means the objects overlap already. Return None
        if the objects will not collide with their current position and
        velocity, including when their only contact lies in the past (they
        are moving apart).'''

        if self.vx is None or o.vx is None:
            return None

        # relative distance
        (dx, dy) = dist(self.x, self.y, o.x, o.y)
        d2 = dx*dx + dy*dy

        # relative speed
        vx = self.vx - o.vx
        vy = self.vy - o.vy
        v2 = vx*vx + vy*vy

        if v2 == 0:
            return None

        # solve |d + v t|^2 = (r1 + r2)^2, i.e. t^2 - 2 p t + q = 0
        p = float(-dx*vx - dy*vy)/v2
        q = float(d2 - (self.radius + o.radius)**2)/v2
        discr = p*p - q

        if discr < 0:
            return None
        discr = math.sqrt(discr)
        t_enter = p - discr     # time the objects start touching
        t_exit = p + discr      # time they separate again

        if t_exit < 0:
            # the whole contact interval lies in the past
            return None
        if t_enter <= 0:
            # contact interval contains "now": already overlapping
            return 0
        return int(t_enter + 0.5)

class Asteroid(Object):
    '''An asteroid, identified by its outline type and its size.'''

    # scaling factor -> collision radius
    _RADII = {
        0: 40,      # large
        15: 20,     # medium
        14: 8,      # small
    }

    def __init__(self, type, sf):
        Object.__init__(self)
        self.type = type    # 1..4, outline shape
        self.sf = sf        # scaling factor: 0 = large, 15 = medium, 14 = small
        # unknown scaling factors keep the default radius from Object
        self.radius = self._RADII.get(sf, self.radius)

    def __str__(self):
        if self.vx is None:
            return 'Asteroid %i/%i: X(%i, %i)' % (self.type, self.sf, self.x, self.y)
        return 'Asteroid %i/%i: X(%i, %i), V(%0.3f, %0.3f)' % (
            self.type, self.sf, self.x, self.y, self.vx, self.vy)

class Shot(Object):
    '''A shot on screen.

    It carries no state beyond Object's, so Object's constructor is inherited
    unchanged.'''

    def __str__(self):
        if self.vx is None:
            return 'Shot: X(%i, %i)' % (self.x, self.y)
        return 'Shot: X(%i, %i), V(%.3f, %.3f)' % (self.x, self.y, self.vx, self.vy)

class Ship(Object):
    '''Player ship'''

    def __init__(self):
        Object.__init__(self)
        # vector the ship's nose points to; None until the screen decoder
        # sets it
        self.lookx = None
        self.looky = None

    def __str__(self):
        if self.vx is not None:
            return 'Ship: X(%i, %i), V(%.3f, %.3f), L(%i, %i)' % (self.x, self.y,
                self.vx, self.vy, self.lookx, self.looky)
        else:
            return 'Ship: X(%i, %i), D(%i, %i)' % (self.x, self.y, self.lookx, self.looky)

    def hit_direction(self, obj):
        '''Return a direction (hit_x, hit_y) for hitting obj.

        Normally this is a unit vector hit such that BULLET_SPEED * hit plus
        the relative velocity (ship - obj) points straight at obj. If either
        velocity is unknown, or no suitable direction exists, the raw (dx, dy)
        vector from the ship to obj (edge-wrapped, not normalized) is
        returned instead.'''

        (deltax, deltay) = dist(obj.x, obj.y, self.x, self.y)

        if self.vx is None or obj.vx is None:
            # best we can do
            return (deltax, deltay)

        # relative speed
        vx = self.vx - obj.vx
        vy = self.vy - obj.vy

        # special case: alpha = deltax / deltay below would divide by (near)
        # zero
        if abs(deltay) <= 5:
            # target roughly level with the ship: the vertical bullet
            # component has to cancel the relative vertical drift
            hity = float(-vy)/BULLET_SPEED
            try:
                hitx = math.sqrt(1 - hity*hity) # ±
            except ValueError:
                hitx = 0

            # pick the sign that points towards the target (like the generic
            # case below): ∢(hit, delta) < 90° ⇒ hit∙delta > 0
            if hitx * deltax + hity * deltay < 0:
                hitx = -hitx
                if hitx * deltax + hity * deltay < 0:
                    # neither sign points towards the target; bail out with
                    # the default
                    return (deltax, deltay)
            return (hitx, hity)

        # Δy is sufficiently large, generic case: solve
        # BULLET_SPEED * (hitx - alpha * hity) = alpha * vy - vx together with
        # hitx² + hity² = 1
        alpha = float(deltax) / deltay
        cx = alpha * vy - vx
        cy = alpha * (-cx)
        baa1 = BULLET_SPEED * (alpha*alpha + 1)
        try:
            discr = math.sqrt(BULLET_SPEED * baa1 - cx*cx)
        except ValueError:
            # no exact solution (target too fast); use the closest guess
            discr = 0

        hitx = (cx - alpha * discr) / baa1
        hity = (cy - discr) / baa1

        # pick the right one: ∢(hit, delta) < 90° ⇒ hit∙delta > 0
        if hitx * deltax + hity * deltay > 0:
            return (hitx, hity)

        # wrong one, pick the other:
        hitx = (cx + alpha * discr) / baa1
        hity = (cy + discr) / baa1
        return (hitx, hity)

class Saucer(Object):
    '''A flying saucer.'''

    # scaling factor -> collision radius
    _RADII = {
        15: 20,     # big
        14: 10,     # small
    }

    def __init__(self, sf):
        Object.__init__(self)
        self.sf = sf    # scaling factor (15: big, 14: small)
        # unknown scaling factors keep the default radius from Object
        self.radius = self._RADII.get(sf, self.radius)

    def __str__(self):
        if self.vx is None:
            return 'Saucer %i: X(%i, %i)' % (self.sf, self.x, self.y)
        return 'Saucer %i: X(%i, %i), V(%.3f, %.3f)' % (self.sf, self.x,
            self.y, self.vx, self.vy)

class Game:
    '''Game state tracked across frames, plus the bot's strategy.

    The screen decoder reports each frame's objects through found_asteroid(),
    found_shot() and the ship/saucer attributes; get_keys() then decides which
    keys to press.'''

    def __init__(self):
        self.asteroids = set()          # tracked Asteroid objects
        self.shots = set()              # tracked Shot objects

        self.ship = None                # Ship, or None when none is known
        self.saucer = None              # Saucer, or None when none is known

        # objects seen in the frame currently being evaluated
        self.moved_asteroids = set()
        self.moved_shots = set()

        # set after pressing fire: the trigger has to be released for a frame
        # before the next shot
        self.release_fire = False

    def start_frame(self):
        '''Setup for beginning evaluation of one frame.'''

        self.moved_shots.clear()
        self.moved_asteroids.clear()

    def finish_frame(self):
        '''Postprocessing after a frame was completely evaluated.

        Forgets every asteroid and shot that was not seen in this frame.'''

        self.asteroids.intersection_update(self.moved_asteroids)
        self.shots.intersection_update(self.moved_shots)

    def _closest(self, objects, x, y):
        '''Return (squared_distance, object) for the object whose predicted
        next position is closest to (x, y), or (None, None) if objects is
        empty.'''

        best_d2 = None
        best = None
        for o in objects:
            (dx, dy) = o.next_dist_to_point(x, y)
            d2 = dx*dx + dy*dy
            if best_d2 is None or d2 < best_d2:
                best = o
                best_d2 = d2
        return (best_d2, best)

    def found_asteroid(self, x, y, type, sf):
        '''Match an asteroid of the given type/size drawn at (x, y).

        The tracked asteroid of the same type and size, not yet matched in
        this frame, whose predicted position is closest to (x, y) is moved
        there if its squared distance is at most 65; otherwise a new Asteroid
        is tracked. Excluding already matched asteroids keeps two asteroids
        drawn close together (e.g. right after a split) from being merged.'''

        candidates = [a for a in self.asteroids
                      if a.type == type and a.sf == sf
                      and a not in self.moved_asteroids]
        (d2, asteroid) = self._closest(candidates, x, y)

        if asteroid is None or d2 > 65:
            asteroid = Asteroid(type, sf)
            self.asteroids.add(asteroid)
        asteroid.move(x, y)
        self.moved_asteroids.add(asteroid)

    def found_shot(self, x, y):
        '''Match a shot drawn at (x, y) to a tracked one, or track a new one.

        Same as found_asteroid(), with a squared distance limit below 65.'''

        candidates = [s for s in self.shots if s not in self.moved_shots]
        (d2, shot) = self._closest(candidates, x, y)

        if shot is None or d2 >= 65:
            shot = Shot()
            self.shots.add(shot)
        shot.move(x, y)
        self.moved_shots.add(shot)

    def close_shot(self):
        '''Check for a shot which is about to destroy the ship.'''

        for shot in self.shots:
            t = self.ship.collides(shot)
            if t is not None and t <= 2:
                print('shot %s will destroy ship %s' % (shot, self.ship))
                return True
        return False

    def best_target(self):
        '''Return (squared_dist, object) with the best target.

        Priority: near asteroid with the soonest predicted collision, then the
        saucer, then the near asteroid needing the least turning, then the
        closest far asteroid. object is None if there is nothing to shoot.'''

        best_collide = None
        best_near = None
        best_far = None
        best_collide_time = 0x7ffffffff
        best_collide_dist = 0x7ffffffff
        best_near_cot = -100000.
        best_near_dist = 0x7ffffffff
        best_far_dist = 0x7ffffff

        for asteroid in self.asteroids:
            (dx, dy) = self.ship.dist_radius(asteroid)
            d2 = dx*dx + dy*dy

            if d2 > NEAR_TRESHOLD:
                # far asteroid
                if d2 < best_far_dist:
                    best_far = asteroid
                    best_far_dist = d2
                continue

            # near asteroid
            t = self.ship.collides(asteroid)
            if t is not None:
                if t < best_collide_time:
                    best_collide = asteroid
                    best_collide_time = t
                    best_collide_dist = d2
                continue

            # cot ∢(ship.look, d) = look ∙ d / |look ⨯ d|
            # the higher the cot, the less we need to turn towards the object;
            # the absolute value makes it independent of turning left or right
            dot = self.ship.lookx * dx + self.ship.looky * dy
            cross = self.ship.lookx * dy - self.ship.looky * dx
            if cross != 0:
                cot = float(dot) / abs(cross)
            elif dot >= 0:
                cot = 1000000.      # dead ahead
            else:
                cot = -1000000.     # straight behind

            if best_near is None or cot > best_near_cot:
                best_near = asteroid
                best_near_cot = cot
                best_near_dist = d2

        # select best target; collision > saucer > near > far
        if best_collide:
            return (best_collide_dist, best_collide)
        if self.saucer:
            (dx, dy) = self.ship.dist_radius(self.saucer)
            return (dx*dx + dy*dy, self.saucer)
        if best_near:
            return (best_near_dist, best_near)
        return (best_far_dist, best_far)

    def get_keys(self, keys, time):
        '''Fill KeyPacket keys with the keys to be sent next.

        This is the central strategic function. time is currently unused.'''

        if not self.ship:
            return

        # TODO: this is currently broken, disable
        #if self.close_shot():
            # staying here would be a certain death, so take chances
        #    keys.hyperspace(True)

        (target_dist, target) = self.best_target()

        if target:
            self.attack(target, keys)

            if target_dist < 27**2:     # cross fingers
                keys.hyperspace(True)

            # accelerate carefully
            if target_dist > 400**2:
                keys.thrust(True)
            elif target_dist > NEAR_TRESHOLD and self.ship.speed() < 9:
                keys.thrust(True)
        else:
            # about to finish level; head to the center, since asteroids always
            # appear at the borders
            self.go_center(keys)

        # need to release the trigger after one shot
        if self.release_fire and not keys.firing():
            self.release_fire = False

    def attack(self, target, keys):
        '''Press the appropriate keys to attack target.

        Rotates towards the aiming direction from Ship.hit_direction() and
        fires when the remaining angle error is small enough to hit target
        (and the trigger was released since the last shot).'''

        # hit direction; hit_direction() may fall back to the raw ship->target
        # vector, so normalize it here
        (hitx, hity) = self.ship.hit_direction(target)
        h = math.sqrt(hitx*hitx + hity*hity)

        # look vector length
        l = math.sqrt(self.ship.lookx*self.ship.lookx + self.ship.looky*self.ship.looky)

        if h == 0 or l == 0:
            # no usable direction; do nothing rather than divide by zero
            return

        hitx /= h
        hity /= h
        lookx = self.ship.lookx / l
        looky = self.ship.looky / l

        # look ⨯ hit (vector product), sin ∢(look, hit)
        vp = lookx * hity - looky * hitx
        # look ∙ hit (scalar product), cos ∢(look, hit)
        sp = lookx * hitx + looky * hity

        # maximum angle error (as sine) under which we can still hit; a target
        # sitting on the ship is hit from any angle
        (distx, disty) = dist(target.x, target.y, self.ship.x, self.ship.y)
        d = math.sqrt(distx*distx + disty*disty)
        hit_angle = float(target.radius) / d if d else 1.0

        # vp points upwards (> 0) -> rotate left
        if vp > 0:
            keys.left(True)
        elif vp < 0:
            keys.right(True)

        if not self.release_fire and sp > 0 and abs(vp) <= hit_angle:
            # well aimed, fire!
            keys.fire(True)
            self.release_fire = True

    def go_center(self, keys):
        '''Press appropriate keys to head to the screen center (512, 384) and
        slow down once within 200 units of it.'''

        # vector to screen center
        centerx = 512 - self.ship.x
        centery = 384 - self.ship.y

        # normalized look vector
        l = math.sqrt(self.ship.lookx*self.ship.lookx + self.ship.looky*self.ship.looky)
        if l == 0:
            return
        lookx = self.ship.lookx / l
        looky = self.ship.looky / l

        if centerx * centerx + centery * centery < 40000:
            # decelerate to stop
            s = math.sqrt(self.ship.speed())    # speed() is squared
            if s <= 2:
                # slow enough
                return

            # normalized speed vector
            sx = self.ship.vx / s
            sy = self.ship.vy / s

            # look ⨯ speed (vector product), sin ∢(look, speed)
            vp = lookx * sy - looky * sx
            # look ∙ speed (scalar product), cos ∢(look, speed)
            sp = lookx * sx + looky * sy

            # turn away from the speed direction (opposite sense of attack())
            if vp > 0:
                keys.right(True)
            elif vp < 0:
                keys.left(True)

            if sp < 0 and abs(vp) <= 0.20:
                # pointing against the speed direction within about 11.5
                # degrees (sin <= 0.20): thrust to brake
                keys.thrust(True)
        else:
            # turn towards center

            # aim from where the ship will be in 5 steps, then normalize
            if self.ship.vx is not None:
                centerx -= 5 * self.ship.vx
                centery -= 5 * self.ship.vy
            l = math.sqrt(centerx * centerx + centery * centery)
            if l == 0:
                # already heading exactly there
                return
            centerx /= l
            centery /= l

            # look ⨯ center (vector product), sin ∢(look, center)
            vp = lookx * centery - looky * centerx
            # look ∙ center (scalar product), cos ∢(look, center)
            sp = lookx * centerx + looky * centery

            # vp points upwards (> 0) -> rotate left
            if vp > 0:
                keys.left(True)
            elif vp < 0:
                keys.right(True)

            if sp > 0 and abs(vp) <= 0.25 and self.ship.speed() <= 25:
                # pointing to center within about 14.5 degrees (sin <= 0.25)
                # and not faster than 5 per step: go!
                keys.thrust(True)

class FramePacket:
    '''One frame packet received from the MAME server.'''

    # 512 16-bit vector RAM words (1024 bytes) + 1 byte frame number + 1 byte
    # ping. The byte order is fixed to little-endian instead of the host's
    # native order, so decoding does not depend on the client machine (on
    # x86 hosts this is identical to the previous native format).
    recv_struct = struct.Struct('<512H2c')

    def __init__(self, sd):
        self.sd = sd                # socket the packets are received from
        self.frame_number = None    # frame counter byte, set by receive()
        self.ping = None            # ping byte, set by receive()
        self.vector_ram = None      # tuple of 512 words, set by receive()

    def receive(self):
        '''Receive and decode one frame packet.

        Returns True on success. On a malformed packet the previous frame
        data is kept and False is returned; after more than 5 malformed
        packets in a row the program exits with status 1.'''
        global invalid_packets

        raw_data = self.sd.recv(PACKET_SIZE)
        try:
            data = self.recv_struct.unpack(raw_data)
        except struct.error:
            print('Received invalid packet, ignoring')
            invalid_packets += 1
            if invalid_packets > 5:
                print('Cannot resync to server, aborting.')
                sys.exit(1)
            return False
        invalid_packets = 0

        self.frame_number = ord(data[-2])
        self.ping = ord(data[-1])
        self.vector_ram = data[0:-2]
        return True

class KeyPacket:
    '''Key state (a bit mask of KEY_* values) to be sent to the MAME server.'''

    KEY_HYPERSPACE = 1
    KEY_FIRE = 2
    KEY_THRUST = 4
    KEY_RIGHT = 8
    KEY_LEFT = 16

    send_struct = struct.Struct('8c')

    def __init__(self, sd, addr=None):
        self.sd = sd        # socket used for sending
        # (host, port) to send to; None means the module-level (HOST, PORT),
        # looked up at send time
        self.addr = addr
        self.keys = 0       # bit mask of KEY_* values
        self.ping = 0       # ping byte sent along with the keys

    def send(self):
        '''Send the 8 byte key packet: "ctmame", key mask, ping byte.'''
        addr = self.addr
        if addr is None:
            addr = (HOST, PORT)
        buffer = self.send_struct.pack('c', 't', 'm', 'a', 'm', 'e', chr(self.keys), chr(self.ping))
        self.sd.sendto(buffer, addr)

    def clear(self):
        '''Release all keys.'''
        self.keys = 0

    def _set(self, mask, b):
        '''Press (b true) or release (b false) the key(s) in mask.'''
        if b:
            self.keys |= mask
        else:
            self.keys &= ~mask

    def hyperspace(self, b):
        self._set(self.KEY_HYPERSPACE, b)

    def fire(self, b):
        self._set(self.KEY_FIRE, b)

    def thrust(self, b):
        self._set(self.KEY_THRUST, b)

    def right(self, b):
        '''Press/release "right"; pressing it releases "left".'''
        self._set(self.KEY_RIGHT, b)
        if b:
            self.left(False)

    def left(self, b):
        '''Press/release "left"; pressing it releases "right".'''
        self._set(self.KEY_LEFT, b)
        if b:
            self.right(False)

    def firing(self):
        '''Return True if the fire key is pressed.'''
        return (self.keys & self.KEY_FIRE) != 0

def interpret_screen(vector_ram, game):
    '''Decode one frame of vector RAM and report what was drawn to game.

    vector_ram is a sequence of 16 bit words; decoding starts at word 1 and
    stops at HALT, RTSL or JMPL, or at the end of the data. Asteroids and
    shots are reported via game.found_asteroid()/found_shot(); game.ship and
    game.saucer are updated, or set to None when not drawn. If vector_ram is
    None (no valid frame received yet), game is left untouched.'''

    if vector_ram is None:
        return

    # JSRL target address -> asteroid outline type
    asteroid_types = {0x8f3: 1, 0x8ff: 2, 0x90d: 3, 0x91a: 4}

    vs = 0              # scale from the last LABS
    vx = vy = vz = 0    # beam position from the last LABS; vz from last VCTR
    dx = dy = 0
    v1x = v1y = 0       # first ship outline vector
    # ship detection: 0 = nothing seen, 1 = first ship vector seen,
    # 2 = ship already reported for this frame
    ship_detect = 0

    found_ship = False
    found_saucer = False

    game.start_frame()

    n = len(vector_ram)
    pc = 1
    while pc < n:
        word = vector_ram[pc]
        op = word >> 12

        # opcodes <= 0xa occupy two words; stop at a truncated instruction
        if op <= 0xa and pc + 1 >= n:
            break

        if op == 0xa: # LABS
            vy = word & 0x3ff
            vx = vector_ram[pc+1] & 0x3ff
            vs = vector_ram[pc+1] >> 12
        elif op in (0xb, 0xd, 0xe): # HALT, RTSL, JMPL
            break
        elif op == 0xc: # JSRL: the subroutine identifies the object drawn
            address = word & 0xfff
            if address in asteroid_types:
                game.found_asteroid(vx, vy, asteroid_types[address], vs)
            elif address == 0x929:
                found_saucer = True
                if not game.saucer:
                    game.saucer = Saucer(vs)
                game.saucer.move(vx, vy)
        elif op == 0xf: # SVEC, ignored
            pass
        else: # VCTR; op is its scale factor
            dy = word & 0x3ff
            if word & 0x400:
                dy = -dy
            dx = vector_ram[pc+1] & 0x3ff
            if vector_ram[pc+1] & 0x400:
                dx = -dx
            vz = vector_ram[pc+1] >> 12     # presumably the beam intensity
            if dx == 0 and dy == 0 and vz == 15:
                game.found_shot(vx, vy)
            if op == 6 and vz == 12 and dx != 0 and dy != 0:
                if ship_detect == 0:
                    v1x = dx
                    v1y = dy
                    ship_detect = 1
                elif ship_detect == 1:
                    found_ship = True
                    if not game.ship:
                        game.ship = Ship()
                    game.ship.lookx = v1x - dx
                    game.ship.looky = v1y - dy
                    game.ship.move(vx, vy)
                    # report the ship only once per frame; further matching
                    # vectors would call move() again with the same position
                    # and decay the velocity estimate
                    ship_detect = 2
            elif ship_detect == 1:
                ship_detect = 0

        if op <= 0xa:
            pc += 1
        if op != 0xe: # JMPL
            pc += 1

    if not found_ship:
        game.ship = None
    if not found_saucer:
        game.saucer = None
    game.finish_frame()

def main(host=HOST):
    '''Run the bot forever against the MAME server at host.

    Each iteration sends the current keys, receives one frame, reports
    latency/frame loss, decodes the frame and computes the next keys.'''
    # KeyPacket.send() addresses the module-level HOST, so make the requested
    # host effective there (previously the argument was silently ignored)
    global HOST
    HOST = host

    sd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    game = Game()
    key_packet = KeyPacket(sd)
    frame_packet = FramePacket(sd)

    time = 0
    prev_frame_number = 0

    while True:
        time += 1
        key_packet.ping = (key_packet.ping + 1) % 256
        prev_frame_number = (prev_frame_number + 1) % 256

        key_packet.send()

        # FramePacket.receive() leaves vector_ram untouched on an invalid
        # packet; reset it so a stale frame is not interpreted again
        frame_packet.vector_ram = None
        frame_packet.receive()
        if frame_packet.vector_ram is None:
            continue

        if (frame_packet.frame_number != prev_frame_number or
            frame_packet.ping != key_packet.ping):
            # both counters wrap at 256
            print('latency: %d lost frames: %d' % (
                (key_packet.ping - frame_packet.ping) % 256,
                (frame_packet.frame_number - prev_frame_number) % 256))

            prev_frame_number = frame_packet.frame_number

        interpret_screen(frame_packet.vector_ram, game)

        key_packet.clear()
        game.get_keys(key_packet, time)


if __name__ == '__main__':
    # optional single argument: IP address of the MAME server
    if len(sys.argv) > 2:
        # sys.exit() with a message prints it to stderr and exits with status 1
        sys.exit('Usage: %s [ mame-ip ]' % sys.argv[0])
    main(*sys.argv[1:])
