﻿using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Gamelogic.Grids;
using UnityEngine;

/// <summary>
/// A point in a layered grid: a point of an individual layer grid together with
/// the index of the layer it belongs to.
/// </summary>
/// <typeparam name="TPoint">The point type of each individual layer.</typeparam>
public struct LayeredPoint<TPoint> : IGridPoint<LayeredPoint<TPoint>> where TPoint : IGridPoint<TPoint>
{
	/// <summary>
	/// The position within the layer.
	/// </summary>
	public TPoint point;

	/// <summary>
	/// The index of the layer.
	/// </summary>
	public int layer;

	/// <summary>
	/// Two layered points are equal when they are on the same layer and their
	/// in-layer points are equal.
	/// </summary>
	public bool Equals(LayeredPoint<TPoint> other)
	{
		return layer == other.layer && point.Equals(other.point);
	}

	/// <summary>
	/// Routes boxed comparisons through <see cref="Equals(LayeredPoint{TPoint})"/> instead of
	/// the reflection-based <see cref="ValueType.Equals(object)"/>.
	/// </summary>
	public override bool Equals(object obj)
	{
		return obj is LayeredPoint<TPoint> && Equals((LayeredPoint<TPoint>) obj);
	}

	/// <summary>
	/// Hash code consistent with <see cref="Equals(LayeredPoint{TPoint})"/>. It combines both
	/// fields, so equal in-layer points on different layers do not always collide.
	/// </summary>
	public override int GetHashCode()
	{
		unchecked
		{
			// EqualityComparer handles a null point (possible when TPoint is a reference type).
			return (EqualityComparer<TPoint>.Default.GetHashCode(point) * 397) ^ layer;
		}
	}

	/// <summary>
	/// Distance between two layered points: the in-layer distance plus the number of
	/// layers separating them.
	/// </summary>
	public int DistanceFrom(LayeredPoint<TPoint> other)
	{
		return point.DistanceFrom(other.point) + Mathf.Abs(layer - other.layer);
	}
}

/// <summary>
/// A grid made of a stack of 2D grids ("layers"). A <see cref="LayeredPoint{TPoint}"/>
/// selects a layer by index and a point inside that layer's grid.
/// </summary>
/// <typeparam name="TCell">The cell type stored in every layer.</typeparam>
/// <typeparam name="TPoint">The point type of each layer grid.</typeparam>
public class LayeredGrid<TCell, TPoint> : IGrid<TCell, LayeredPoint<TPoint>> where TPoint : IGridPoint<TPoint>
{
	// layers[i] is the grid that stores layer i.
	private readonly IGrid<TCell, TPoint>[] layers;

	/// <summary>
	/// The number of layers in this grid.
	/// </summary>
	public int LayerCount
	{
		get
		{
			return layers.Length;
		}
	}

	/// <summary>
	/// Creates a grid with <paramref name="layerCount"/> layers, all built by the same factory.
	/// </summary>
	/// <param name="width">Width passed to the layer factory.</param>
	/// <param name="height">Height passed to the layer factory.</param>
	/// <param name="layerCount">Number of layers to create.</param>
	/// <param name="contains">Containment predicate passed to the layer factory.</param>
	/// <param name="constructor">Factory called once per layer with (width, height, contains).</param>
	public LayeredGrid(int width, int height, int layerCount, Func<TPoint, bool> contains, Func<int, int, Func<TPoint, bool>, IGrid<TCell, TPoint>> constructor)
		: this(width, height, layerCount, contains, (layerIndex, w, h, c) => constructor(w, h, c))
	{
	}

	/// <summary>
	/// Creates a grid with <paramref name="layerCount"/> layers. The factory also receives
	/// the index of the layer it is building.
	/// </summary>
	/// <param name="width">Width passed to the layer factory.</param>
	/// <param name="height">Height passed to the layer factory.</param>
	/// <param name="layerCount">Number of layers to create.</param>
	/// <param name="contains">Containment predicate passed to the layer factory.</param>
	/// <param name="constructor">Factory called once per layer with (layerIndex, width, height, contains).</param>
	public LayeredGrid(int width, int height, int layerCount, Func<TPoint, bool> contains, Func<int, int, int, Func<TPoint, bool>, IGrid<TCell, TPoint>> constructor)
	{
		layers = new IGrid<TCell, TPoint>[layerCount];

		for (int i = 0; i < layerCount; i++)
		{
			layers[i] = constructor(i, width, height, contains);
		}
	}

	/// <summary>
	/// Builds a layered grid with one layer per shape. Layer i is the grid produced by
	/// <c>shapes[i].EndShape()</c>.
	/// </summary>
	public static LayeredGrid<TCell, TPoint2> Make<TShapeInfo, TGrid, TPoint2, TVectorPoint, TShapeOp>(
		TShapeInfo[] shapes)

		where TShapeInfo : AbstractShapeInfo<TShapeInfo, TGrid, TPoint2, TVectorPoint, TShapeOp> 
		where TPoint2 : IGridPoint<TPoint2>, ISplicedVectorPoint<TPoint2, TVectorPoint>
		where TVectorPoint : IVectorPoint<TVectorPoint>
		where TGrid : IGrid<TCell, TPoint2>

	{
		var layers = new IGrid<TCell, TPoint2>[shapes.Length];
		for (int i = 0; i < layers.Length; i++)
		{
			layers[i] = shapes[i].EndShape();
		}

		return new LayeredGrid<TCell, TPoint2>(layers);
	}

	/// <summary>
	/// Builds a layered grid of <paramref name="layerCount"/> layers that share one shape.
	/// Layer 0 is built from <paramref name="shape"/>; every other layer is a
	/// <c>CloneStructure</c> of layer 0, with the same points and separate cell storage.
	/// </summary>
	public static LayeredGrid<TCell, TPoint2> Make<TShapeInfo, TGrid, TPoint2, TVectorPoint, TShapeOp>(int layerCount,
		TShapeInfo shape)

		where TShapeInfo : AbstractShapeInfo<TShapeInfo, TGrid, TPoint2, TVectorPoint, TShapeOp>
		where TPoint2 : IGridPoint<TPoint2>, ISplicedVectorPoint<TPoint2, TVectorPoint>
		where TVectorPoint : IVectorPoint<TVectorPoint>
		where TGrid : IGrid<TCell, TPoint2>
	{
		var layers = new IGrid<TCell, TPoint2>[layerCount];

		if (layerCount > 0)
		{
			layers[0] = shape.EndShape();
		}

		for (int i = 1; i < layerCount; i++)
		{
			layers[i] = layers[0].CloneStructure<TCell>();
		}

		return new LayeredGrid<TCell, TPoint2>(layers);
	}

	/// <summary>
	/// Wraps an existing array of layer grids. The array is stored as-is, not copied.
	/// </summary>
	protected LayeredGrid(IGrid<TCell, TPoint>[] layers)
	{
		this.layers = layers;
	}

	/// <summary>
	/// True when the point's layer index exists and that layer contains the in-layer point.
	/// </summary>
	public bool Contains(LayeredPoint<TPoint> point)
	{
		if (point.layer < 0 || point.layer >= layers.Length)
		{
			return false;
		}

		return layers[point.layer].Contains(point.point);
	}

	/// <summary>
	/// Enumerates every point of every layer, layer by layer, starting at layer 0.
	/// </summary>
	public IEnumerator<LayeredPoint<TPoint>> GetEnumerator()
	{
		for (int i = 0; i < layers.Length; i++)
		{
			foreach (var layerPoint in layers[i])
			{
				yield return new LayeredPoint<TPoint>
				{
					point = layerPoint,
					layer = i
				};
			}
		}
	}

	IEnumerator IEnumerable.GetEnumerator()
	{
		return GetEnumerator();
	}

	/// <summary>
	/// Returns a layered grid with the same layer structure but a new cell type.
	/// Each layer is cloned with its own <c>CloneStructure</c>.
	/// </summary>
	public IGrid<TNewCell, LayeredPoint<TPoint>> CloneStructure<TNewCell>()
	{
		return new LayeredGrid<TNewCell, TPoint>(layers.Select(layer => layer.CloneStructure<TNewCell>()).ToArray());
	}

	/// <summary>
	/// Neighbor relationships between layers are application-specific, so this base
	/// implementation defines none. Override it in a subclass to provide them.
	/// </summary>
	/// <exception cref="NotImplementedException">Always, unless overridden.</exception>
	public virtual IEnumerable<LayeredPoint<TPoint>> GetAllNeighbors(LayeredPoint<TPoint> point)
	{
		throw new NotImplementedException("LayeredGrid does not define neighbors; override GetAllNeighbors in a subclass.");
	}

	/// <summary>
	/// Returns the union of each layer's large set for the layers whose index lies in
	/// [-n, n]. Only layers that exist can be queried, so the window is clamped to
	/// [0, LayerCount - 1].
	/// </summary>
	public IEnumerable<LayeredPoint<TPoint>> GetLargeSet(int n)
	{
		var largeSet = new List<LayeredPoint<TPoint>>();

		// Indexing layers[-n] directly, as a naive -n..n loop would, throws for any n > 0.
		// It also throws when n reaches past the last layer, so clamp to existing layers.
		int firstLayer = Math.Max(-n, 0);
		int lastLayer = Math.Min(n, layers.Length - 1);

		for (int i = firstLayer; i <= lastLayer; i++)
		{
			int layerIndex = i; // per-iteration copy for the lambda below

			largeSet.AddRange(layers[layerIndex].GetLargeSet(n).Select(largeSetPoint => new LayeredPoint<TPoint>
			{
				point = largeSetPoint,
				layer = layerIndex
			}));
		}

		return largeSet;
	}

	/// <summary>
	/// Returns the storage points of every layer, tagged with their layer index,
	/// starting at layer 0.
	/// </summary>
	public IEnumerable<LayeredPoint<TPoint>> GetStoragePoints()
	{
		var storagePoints = new List<LayeredPoint<TPoint>>();

		for (int i = 0; i < LayerCount; i++)
		{
			int layerIndex = i; // per-iteration copy for the lambda below

			storagePoints.AddRange(layers[layerIndex].GetStoragePoints().Select(layerPoint => new LayeredPoint<TPoint>
			{
				point = layerPoint,
				layer = layerIndex
			}));
		}

		return storagePoints;
	}

	/// <summary>
	/// Gets or sets the cell at the given point by delegating to the point's layer.
	/// An out-of-range layer index throws from the array access.
	/// </summary>
	public TCell this[LayeredPoint<TPoint> point]
	{
		get
		{
			return layers[point.layer][point.point];
		}

		set
		{
			layers[point.layer][point.point] = value;
		}
	}

	/// <summary>
	/// The cells of all layers, in the same order as <see cref="GetEnumerator"/>.
	/// </summary>
	public IEnumerable<TCell> Values
	{
		get
		{
			return this.Select(p => this[p]);
		}
	}
}

/// <summary>
/// Maps layered points to 3D positions by stacking horizontal planes. The base 2D map
/// supplies the horizontal position: its x becomes world X and its y becomes world Z.
/// The layer index supplies the height: layer k sits at layerOffset + k * layerDistance.
/// </summary>
/// <typeparam name="TPoint">The point type of each layer.</typeparam>
public class SimpleLayeredMap<TPoint> : IMap3D<LayeredPoint<TPoint>> 
	where TPoint : IGridPoint<TPoint>
{
	// 2D map shared by every layer.
	private readonly IMap<TPoint> planeMap;

	// Vertical spacing between consecutive layers.
	private readonly float heightPerLayer;

	// Height of layer 0.
	private readonly float baseHeight;

	/// <summary>
	/// Creates the map.
	/// </summary>
	/// <param name="baseMap">The 2D map used for the horizontal position within every layer.</param>
	/// <param name="layerDistance">Vertical distance between consecutive layers. The inverse
	/// lookup divides by this value, so zero makes world-to-grid conversion meaningless.</param>
	/// <param name="layerOffset">Height of layer 0.</param>
	public SimpleLayeredMap(IMap<TPoint> baseMap, float layerDistance, float layerOffset)
	{
		planeMap = baseMap;
		heightPerLayer = layerDistance;
		baseHeight = layerOffset;
	}


	/// <summary>
	/// Not supported by this map.
	/// </summary>
	public IMap<LayeredPoint<TPoint>> To2D()
	{
		throw new NotImplementedException();
	}

	/// <summary>
	/// Converts a layered grid point to its world position.
	/// </summary>
	public Vector3 this[LayeredPoint<TPoint> point]
	{
		get
		{
			var flat = planeMap[point.point];
			float elevation = point.layer * heightPerLayer;

			return new Vector3(flat.x, elevation + baseHeight, flat.y);
		}
	}

	/// <summary>
	/// Converts a world position to a layered grid point. The height is rounded to the
	/// nearest layer index, and (x, z) is resolved through the base 2D map.
	/// </summary>
	public LayeredPoint<TPoint> this[Vector3 point]
	{
		get
		{
			var result = new LayeredPoint<TPoint>();

			result.layer = Mathf.RoundToInt((point.y - baseHeight) / heightPerLayer);
			result.point = planeMap[new Vector2(point.x, point.z)];

			return result;
		}
	}
}
