#pragma once

#include "utils/hex_form.hpp"

#include "nes_memory_map.hpp"
#include "reference.hpp"
#include "instruction.hpp"
#include "nes_memory_mapper.hpp"
#include "state_machine.hpp"

namespace cdp {
/** Per-byte flags of the code/data parser.
 *  The values are bit flags and are OR-ed into a flag byte
 *  (e.g. `flags[id] |= cdp::IND_DATA`).
 */
enum CodeDataParserEnum {
    BLANK = 0,          ///< no flag set
    CODE  = 1 << 0,     ///< presumably: byte classified as code
    DATA  = 1 << 1,     ///< presumably: byte classified as data

    // Two-bit field in bits 2-3: MAPPED_ORG is the field mask and
    // MAPPED_ORG0..MAPPED_ORG3 are the four values the field can hold.
    MAPPED_ORG  = 3 << 2,
    MAPPED_ORG0 = 0 << 2,
    MAPPED_ORG1 = 1 << 2,
    MAPPED_ORG2 = 2 << 2,
    MAPPED_ORG3 = 3 << 2,

    IND_CODE = 1 << 4,  ///< presumably: code reached through an indirect reference
    IND_DATA = 1 << 5,  ///< presumably: data reached through an indirect reference
    PCM      = 1 << 6,  ///< presumably: PCM sample data — confirm
};
}



/** Infomation of Memory Access for a Instruction*/
class MemoryAccess
{
	MemoryID m_none;
	MemoryID m_a;
	MemoryID m_x;
	MemoryID m_y;
	
	MemoryID m_accessed;
	MemoryID m_used;
	MemoryID m_dead;
public:
	MemoryAccess(StateMachine& state)
			:m_none(state.mapper().size()),
			 m_a(m_none+1),
			 m_x(m_none+2),
			 m_y(m_none+3),
			 m_accessed(m_none),
			 m_used(m_none),
			 m_dead(m_none)
	{
		init(state);
	}
	MemoryAccess(const MemoryAccess& that)
			:m_none(that.m_none),
			 m_a(that.m_a),
			 m_x(that.m_x),
			 m_y(that.m_y),
			 m_accessed(that.m_accessed),
			 m_used(that.m_used),
			 m_dead(that.m_dead)
	{
	}
	bool isRegister(MemoryID id) { return id > m_none; }
	MemoryID usedID()const { return m_used; }
	MemoryID accessedID()const { return m_accessed; }
	MemoryID deadID()const { return m_dead; }
	MemoryID memory(StateMachine& state)
	{
		const NesMemoryMapper& mapper=state.mapper();
		PInstruction instr=state.nextInstruction();
		Word operand = instr->operand();
		AddressingMode mode = instr->mode();
		MemoryID id = mapper.id(operand);
		switch (mode) {
		case opcode::ZERO: case opcode::ABS:
		case opcode::ZERO_X: case opcode::ABS_X:
		case opcode::ZERO_Y: case opcode::ABS_Y: 
			m_accessed = mapper.id(operand);
			return operand;
		case opcode::PRE_IND: // ex. lda ($50, x)
			m_accessed=mapper.id(
					mapper.wordData(
							mapper.id(bit8::addressIndex(
											  operand, state.x()))) );
			return operand;
		case opcode::POST_IND:// ex. lda ($50), y
			m_accessed = mapper.id(mapper.wordData(id));
			
			return operand;
		case opcode::REG_A:
			return m_a;
		case opcode::IND:
			m_accessed = mapper.wordData(id);
			return operand;
		case opcode::IMM:
		default:
			return m_none;
		}
	}
	
	void init(StateMachine& state) 
	{
		PInstruction instr=state.nextInstruction();
		
		switch (instr->type()) {
		case opcode::JMP: m_used = memory(state); break;
		case opcode::LDA: m_used = memory(state); m_dead = m_a; break;
		case opcode::LDX: m_used = memory(state); m_dead = m_x; break;
		case opcode::LDY: m_used = memory(state); m_dead = m_y; break;
			
		case opcode::TAX: m_used = m_a; m_dead = m_x; break;
		case opcode::TAY: m_used = m_a; m_dead = m_y; break;
		case opcode::TSX: break;
		case opcode::TXA: m_used = m_x; m_dead = m_a; break;
		case opcode::TXS: break;
		case opcode::TYA: m_used = m_y; m_dead = m_a; break;
			
		case opcode::PLA: m_used = state.stackAddress(); m_dead = m_a;
			break;
			
		case opcode::STA: m_used = m_a; m_dead = memory(state); break;
		case opcode::STX: m_used = m_x; m_dead = memory(state); break;
		case opcode::STY: m_used = m_y; m_dead = memory(state); break;
			
		default:
			return;
		}
		
	}
};
/** Class ReferenceJumpTable is find jump table by state machine.*/
class ReferenceJumpTable
{
public:
	ByteSequence& m_flags;
	std::vector<MemoryAccess> m_history;
	Word m_table_address;
	std::vector<Word> jump_table;
	
	ReferenceJumpTable(ByteSequence& flags)
			:m_flags(flags)
	{
	}
	Word tableAddress()const { return m_table_address; }
    bool found()const { return !jump_table.empty(); }
    
	const std::vector<Word>& parse(StateMachine& state)
	{
		jump_table.clear();
		
		const NesMemoryMapper& mapper=state.mapper();
		
		while (true) {
			m_history.push_back(MemoryAccess(state));
			PInstruction instr = state.nextInstruction();
			
			Word addr = instr->address();
			
			if (addr >= nes::CODE_END) {
				return jump_table;
			}
			MemoryID id = mapper.id(addr);
			
			for (unsigned i=0; i<instr->bytelength(); ++i) {
				m_flags[id+i] |= cdp::IND_DATA;
			}
			
			OpcodeType optype = instr->type();
			AddressingMode mode = instr->mode();
			ReferenceType type = ReferenceHelper::map(optype);
			
			if (type == reference::JUMP && mode == opcode::IND) {
				break;
			} else  {
				state.execute(instr);
			}
		}
		MemoryID tableID = findTable();
		if (tableID == 0) {
			return jump_table;
		}
		
		m_table_address = mapper.hardAddress(tableID);
		
		int table_is_one_byte_after=0;
		
		for (Word i=0; i<16; i+=2) {
			Word stay=mapper.wordData(tableID + i);
			Word change=mapper.wordData(tableID + i+1);
			
			if (nes::CODE_OFFSET <= stay && stay < nes::CODE_END) {
				table_is_one_byte_after-=1;
			}
			if (nes::CODE_OFFSET <= change && change < nes::CODE_END) {
				table_is_one_byte_after+=1;
			}
		}
		if (table_is_one_byte_after > 0) {
			tableID+=1;
			m_table_address+=1;
		}
		
		for (MemoryID i=tableID; i<nes::CODE_END-1; i+=2) {
			Word addr = mapper.wordData(i);
			if (nes::CODE_OFFSET < addr && addr <= nes::CODE_END) {
				jump_table.push_back(addr);
			} else {
				break;
			}
		}
		unsigned bad_points=0;
		Word return_addr = state.returnAddress();
		for (unsigned i=0; i<jump_table.size(); ++i) {
			long delta = jump_table[i] - return_addr;
			delta = (delta>0)?delta:-delta;
			if (delta >= nes::BANK_SIZE) {
				bad_points++;
			}
		}
		if (bad_points * 2 > jump_table.size()) {
			jump_table.clear();
		}
		return jump_table;
	}
private:
	MemoryID findTable()
	{
		unsigned size = m_history.size();
		MemoryID usedID=m_history[size-1].usedID();
		
		for (int i=size-2; i > 0; --i) {
			MemoryAccess& memory=m_history[i];
			
			if (memory.deadID() == usedID) {
				usedID = memory.usedID();
				if (! memory.isRegister(usedID)){
					return memory.accessedID();
				} 
			}
		}
		return 0;
	}
};

