require('Module:Lua class')
local libraryUtil = require('libraryUtil')
local TableTools = require('Module:TableTools')
local warn = require('Module:Warning')

-- Element types that can be used directly as Lua table keys. Elements of these
-- types are stored in a set's `_elements` table keyed by themselves; any other
-- element must be an object with a hash() method (see frozenset._set).
local basic_types = {boolean=1, string=1, number=1, ['function']=1}


-- Pseudo-hashes for function elements, so that functions can be ordered in
-- frozenset._keySort and serialized in frozenset._hash / __tostring.
-- A hash is generated lazily on first lookup and memoized in the table; keys are
-- weak so this table does not keep functions alive.
local f_hashes = {}
local f_hash_count = 0 -- number of hashes issued so far; makes every hash input unique
local f_hashes_mt = {
	__index = function (t, key)
		-- tostring() rounds numbers (to 14 significant digits in stock Lua 5.1),
		-- so `os.time() + math.random()` alone keeps only a few random decimals and
		-- two functions hashed in the same second could get the same input string,
		-- hence the same hash. Prefixing a counter guarantees distinct inputs.
		f_hash_count = f_hash_count + 1
		local seed = f_hash_count .. ':' .. tostring(os.time() + math.random())
		local h = tonumber('0x' .. mw.hash.hashValue('fnv1a32', seed))
		f_hashes[key] = h
		return h
	end,
	__mode = 'k'
}
setmetatable(f_hashes, f_hashes_mt)


-- Immutable set of hashable elements, modelled on Python's frozenset.
--
-- Storage: elements whose type is listed in basic_types live in self._elements
-- (elem -> elem); every other element is an object exposing hash() and lives in
-- self._elements_obj, keyed by an integer slot found by linear probing from its
-- hash (see _set).
--
-- The locals are declared before the class body so that the methods below which
-- construct frozensets (has, issubset, issuperset, symmetric_difference) capture
-- them. With `local frozenset = class(...)` the name is not yet in scope inside
-- the constructor expression, so those references would resolve to a global.
local frozenset, _frozenset
frozenset, _frozenset = class('frozenset', {
	-- Builds the set from `args`:
	--   {}      -> empty set (the table is not inspected further)
	--   {str}   -> one element per character of str (via mw.ustring)
	--   {iter}  -> if pairs/ipairs accept args[1], the elements come from it
	--   other   -> the elements come from args itself
	-- For an array-like source (TableTools.isArrayLike) the values are the
	-- elements; otherwise the first value yielded by pairs() is.
	-- Raises "TypeError" for plain tables and for objects without `hash`.
	__init = function (self, args)
		local elements = {}
		self._elements = elements -- basic-type elements: elem -> elem
		self._elements_obj = {} -- object elements: probe slot -> elem

		if #args == 0 then -- for performance
			return
		elseif #args == 1 then
			local arg = args[1]
			if type(arg) == 'string' then
				-- a lone string contributes its (Unicode) characters
				local c
				for i = 1, mw.ustring.len(arg) do
					c = mw.ustring.sub(arg, i,i)
					elements[c] = c
				end
				return
			elseif pcall(pairs, arg) or pcall(ipairs, arg) then
				args = arg -- a single iterable argument supplies the elements
			end
		end

		if TableTools.isArrayLike(args) then
			for i, v in ipairs(args) do
				if type(v) == 'table' or isinstance(v) and v.hash == nil then
					error(("TypeError: invalid element #%d type (got %s, which is not hashable)"):format(i, type(v)), 3)
				end
				self._set(v)
			end
		else
			for k in pairs(args) do
				if type(k) == 'table' or isinstance(k) and k.hash == nil then
					error(("TypeError: invalid element type (got %s, which is not hashable)"):format(type(k)), 3)
				end
				self._set(k)
			end
		end
	end,

	-- Returns the stored element equal to `elem`, or nil when absent.
	-- Object elements are searched from slot elem.hash() upwards, comparing with
	-- ==, until an empty slot is reached.
	_get = function (self, elem)
		if basic_types[type(elem)] then
			return self._elements[elem]
		else
			local elements_obj = self._elements_obj
			local h = elem.hash()

			while elements_obj[h] ~= nil do
				if elements_obj[h] == elem then
					return elem
				end
				h = h + 1
			end
		end
	end,

	-- Stores `elem` unless an equal element is already present.
	-- Object elements go into the first free slot at or after elem.hash()
	-- (linear probing), so distinct objects with colliding hashes coexist.
	_set = function (self, elem)
		if basic_types[type(elem)] then
			self._elements[elem] = elem
		else
			local elements_obj = self._elements_obj
			local h = elem.hash()

			while elements_obj[h] ~= nil do
				if elements_obj[h] == elem then
					return
				end
				h = h + 1
			end
			elements_obj[h] = elem -- otherwise different objects with the same content would duplicate
		end
	end,

	-- Iterates over the elements, each returned as the loop's first variable:
	-- first the basic-type elements, then the object elements.
	__pairs = function (self)
		local elems = self._elements
		local k, v

		local function iterator()
			k, v = next(elems, k)
			if k == nil and elems == self._elements then
				-- basic-type elements exhausted: continue with the object elements
				elems = self._elements_obj
				k, v = next(elems, k)
			end
			return v -- nil at the very end
		end

		return iterator
	end,

	__ipairs = function (self)
		error("IterationError: a set is unordered, use 'pairs' instead", 2)
	end,

	-- table.sort comparator used by _hash to put elements in a canonical order.
	-- Orders by type name first ("number" < "string"), then numbers/strings with <,
	-- booleans by their string form, functions by f_hashes and objects by hash().
	-- Warns when two distinct functions/objects share a hash.
	_keySort = function (item1, item2)
		local type1, type2 = type(item1), type(item2)
		if type1 ~= type2 then
			return type1 < type2
		elseif type1 == 'number' or type1 == 'string' then
			return item1 < item2
		elseif type1 == 'boolean' then
			return tostring(item1) < tostring(item2)
		else
			local hash1, hash2
			if type1 == 'function' then
				hash1, hash2 = f_hashes[item1], f_hashes[item2]
				if hash1 == hash2 then
					warn(("HashWarning: function hash collision at %d"):format(hash1), 2) -- TODO: what should the level be?
				end
			else
				hash1, hash2 = item1.hash(), item2.hash()
				if hash1 == hash2 then
					warn(("HashWarning: object hash collision at %d"):format(hash1), 2)
				end
			end
			return hash1 < hash2
		end
	end,

	-- Returns a 32-bit FNV-1a hash of the set's contents, cached in self.__hash.
	-- Equal frozensets may iterate in different orders, so the elements are
	-- sorted with _keySort and serialized before hashing.
	_hash = function (self)
		if self.__hash ~= nil then
			return self.__hash
		end

		local ordered_elems = TableTools.keysToList(self, self._keySort, true)

		-- Serialize each element so that elements of different types (and
		-- different strings) can never produce the same text.
		local elemType
		for i, elem in ipairs(ordered_elems) do
			elemType = type(elem)
			if elemType == 'number' or elemType == 'boolean' then
				ordered_elems[i] = tostring(elem)
			elseif elemType == 'string' then
				-- %q escapes quotes, so a string containing "','" cannot mimic two elements
				ordered_elems[i] = ('%q'):format(elem)
			elseif elemType == 'function' then
				ordered_elems[i] = 'f' .. f_hashes[elem]
			else
				ordered_elems[i] = 'o' .. elem.hash()
			end
		end

		local str = '{' .. table.concat(ordered_elems, ',') .. '}' -- wrap in {} to differentiate from tuple
		self.__hash = tonumber('0x' .. mw.hash.hashValue('fnv1a32', str))
		return self.__hash
	end,

	-- Renders the set as "{e1, e2, ...}" in iteration order; strings are quoted
	-- and functions shown as 'f' followed by their pseudo-hash.
	__tostring = function (self)
		local string_elems = {}
		local elemType

		for elem in pairs(self) do
			elemType = type(elem)
			if elemType == 'string' then
				string_elems[#string_elems+1] = "'" .. elem .. "'"
			elseif elemType == 'function' then
				string_elems[#string_elems+1] = 'f' .. f_hashes[elem]
			else
				string_elems[#string_elems+1] = tostring(elem)
			end
		end

		local str = '{' .. table.concat(string_elems, ', ') .. '}'
		return str
	end,

	-- Number of elements.
	len = function (self)
		return TableTools.size(self._elements) + TableTools.size(self._elements_obj)
	end,

	-- Membership test. A `set` argument is looked up as the equivalent frozenset;
	-- plain tables and objects without `hash` raise "TypeError".
	has = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif type(elem) == 'table' or isinstance(elem) and elem.hash == nil then
			error(("TypeError: invalid element type (got %s, which is not hashable)"):format(type(elem)), 2)
		end
		return self._get(elem) ~= nil
	end,

	-- True when self and `other` (a set or frozenset) share no element.
	isdisjoint = function (self, other)
		libraryUtil.checkTypeMulti('isdisjoint', 1, other, {'set', 'frozenset'})
		for elem in pairs(other) do
			if self._get(elem) ~= nil then
				return false
			end
		end
		return true
	end,

	-- True when every element of self is in `other` (any frozenset-constructible value).
	issubset = function (self, other)
		return self <= frozenset{other}
	end,

	-- a <= b: every element of a is an element of b.
	__le = function (a, b)
		for elem in pairs(a) do
			if b._get(elem) == nil then
				return false
			end
		end
		return true
	end,

	-- a < b: a is a proper subset of b.
	__lt = function (a, b)
		return a <= b and a.len() < b.len() -- is calculating a's length during its traversal in __le faster?
	end,

	-- True when every element of `other` (any frozenset-constructible value) is in self.
	issuperset = function (self, other)
		return self >= frozenset{other}
	end,

	-- Returns a new set holding the elements of self and of every argument.
	-- NOTE(review): the local `set` is declared after this class body, so `set`
	-- here resolves to a global at call time; confirm Module:Lua class provides it.
	union = function (self, ...)
		local sum = set{self}
		sum.update(...)
		return sum
	end,

	-- a + b: union, as an instance of a's class.
	__add = function (a, b)
		local sum, _sum = a.__class{}
		for elem in pairs(a) do
			_sum._set(elem)
		end
		for elem in pairs(b) do
			_sum._set(elem)
		end
		return sum
	end,

	-- Returns a new set holding the elements common to self and every argument.
	-- NOTE(review): relies on a global `set` (see union).
	intersection = function (self, ...)
		local product = set{self}
		product.intersection_update(...)
		return product
	end,

	-- a * b: intersection, as an instance of a's class.
	__mul = function (a, b)
		local product, _product = a.__class{}
		for elem in pairs(a) do
			if b._get(elem) ~= nil then
				_product._set(elem)
			end
		end
		return product
	end,

	-- Returns a new set holding the elements of self found in no argument.
	-- NOTE(review): relies on a global `set` (see union).
	difference = function (self, ...)
		local difference = set{self}
		difference.difference_update(...)
		return difference
	end,

	-- a - b: elements of a not in b, as an instance of a's class.
	__sub = function (a, b)
		local difference, _difference = a.__class{}
		for elem in pairs(a) do
			if b._get(elem) == nil then
				_difference._set(elem)
			end
		end
		return difference
	end,

	-- Elements in exactly one of self and `other` (any frozenset-constructible value).
	symmetric_difference = function (self, other)
		return self ^ frozenset{other}
	end,

	-- a ^ b: symmetric difference, as an instance of a's class.
	__pow = function (a, b)
		local symm_diff, _symm_diff = a.__class{}
		for elem in pairs(a) do
			if b._get(elem) == nil then
				_symm_diff._set(elem)
			end
		end
		for elem in pairs(b) do
			if a._get(elem) == nil then
				_symm_diff._set(elem)
			end
		end
		return symm_diff
	end,

	-- Shallow copy as a new instance of the same class.
	copy = function (self)
		return (self.__class{self}) -- to not leak the private instance
	end,

	-- Equality: each operand is a subset of the other.
	__eq = function (a, b)
		return a <= b and a >= b
	end,

	-- presumably consumed by Module:Lua class to configure member visibility
	__staticmethods = {'_keySort'},
	__protected = {'_get', '_set'}
})


-- Collects into an array the elements of `s` for which `predicate(elem)` is true.
-- The in-place update methods below use it so that no element is deleted while
-- `pairs(s)` is still traversing `s` (set._del may move object elements between
-- slots). Returns the array and its length; the array may contain `false`.
local function collect_elements(s, predicate)
	local picked, n = {}, 0
	for elem in pairs(s) do
		if predicate(elem) then
			n = n + 1
			picked[n] = elem
		end
	end
	return picked, n
end


-- Mutable set: frozenset's storage and operations plus in-place modification.
local set = class('set', frozenset, {
	-- Removes `elem` if present; does nothing otherwise.
	-- Object elements are stored with linear probing (see frozenset._set), so just
	-- clearing the slot would break the probe chain: any element that probed past
	-- this slot would no longer be found (and could be re-added as a duplicate).
	-- Instead the following run of occupied slots is compacted (backward-shift
	-- deletion): an element whose home slot (its hash) is at or before the hole
	-- is moved into the hole. Only occupied slots are ever overwritten.
	_del = function (self, elem)
		if basic_types[type(elem)] then
			self._elements[elem] = nil
			return
		end

		local elements_obj = self._elements_obj
		local hole = elem.hash()

		-- locate elem along its probe sequence
		while true do
			local current = elements_obj[hole]
			if current == nil then
				return -- elem is not in the set
			elseif current == elem then
				break
			end
			hole = hole + 1
		end

		-- shift back later elements whose probe sequence passes through the hole
		local probe = hole + 1
		local occupant = elements_obj[probe]
		while occupant ~= nil do
			if occupant.hash() <= hole then
				elements_obj[hole] = occupant
				hole = probe
			end
			probe = probe + 1
			occupant = elements_obj[probe]
		end
		elements_obj[hole] = nil
	end,

	-- Adds the elements of every argument (each converted with frozenset{...}).
	update = function (self, ...)
		local others, other = {...}
		for i = 1, select('#', ...) do
			other = frozenset{others[i]}
			for elem in pairs(other) do
				self._set(elem)
			end
		end
	end,

	-- Keeps only the elements also present in every argument.
	intersection_update = function (self, ...)
		local others = {...}
		for i = 1, select('#', ...) do
			local _, _other = _frozenset{others[i]}
			local doomed, n = collect_elements(self, function (elem)
				return _other._get(elem) == nil
			end)
			for j = 1, n do
				self._del(doomed[j])
			end
		end
	end,

	-- Removes every element present in any argument.
	difference_update = function (self, ...)
		local others = {...}
		for i = 1, select('#', ...) do
			local _, _other = _frozenset{others[i]}
			local doomed, n = collect_elements(self, function (elem)
				return _other._get(elem) ~= nil
			end)
			for j = 1, n do
				self._del(doomed[j])
			end
		end
	end,

	-- Keeps the elements found in exactly one of self and `other`.
	-- `other` is partitioned against the current contents before anything is
	-- changed; otherwise shared elements removed first would look absent
	-- afterwards and be added straight back.
	symmetric_difference_update = function (self, other)
		local _, _other = _frozenset{other}
		local shared, n_shared = {}, 0
		local missing, n_missing = {}, 0
		for elem in pairs(_other) do
			if self._get(elem) ~= nil then
				n_shared = n_shared + 1
				shared[n_shared] = elem
			else
				n_missing = n_missing + 1
				missing[n_missing] = elem
			end
		end
		for i = 1, n_shared do
			self._del(shared[i])
		end
		for i = 1, n_missing do
			self._set(missing[i])
		end
	end,

	-- Adds a single hashable element; plain tables and objects without `hash`
	-- raise "TypeError".
	add = function (self, elem)
		if type(elem) == 'table' or isinstance(elem) and elem.hash == nil then
			error(("TypeError: invalid element type (got %s, which is not hashable)"):format(type(elem)), 2)
		end
		self._set(elem)
	end,

	-- Removes `elem` (a `set` is looked up as the equivalent frozenset).
	-- Raises "KeyError" when it is absent, "TypeError" when it is unhashable.
	remove = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif type(elem) == 'table' or isinstance(elem) and elem.hash == nil then
			error(("TypeError: invalid element type (got %s, which is not hashable)"):format(type(elem)), 2)
		end
		if self._get(elem) == nil then
			error(("KeyError: %s"):format(tostring(elem)), 2)
		end
		self._del(elem)
	end,

	-- Like remove, but silently ignores an absent element.
	discard = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif type(elem) == 'table' or isinstance(elem) and elem.hash == nil then
			error(("TypeError: invalid element type (got %s, which is not hashable)"):format(type(elem)), 2)
		end
		self._del(elem)
	end,

	-- Removes and returns an arbitrary element (basic-type elements are taken
	-- before object elements). Raises "KeyError" on an empty set.
	pop = function (self)
		local k, v = next(self._elements)
		if k == nil then
			k, v = next(self._elements_obj)
			if k == nil then
				error("KeyError: pop from an empty set", 2)
			end
		end
		self._del(v)
		return v
	end,

	-- Removes every element by replacing both storage tables.
	clear = function (self)
		self._elements = {}
		self._elements_obj = {}
	end,

	__protected = {'_del'}
})

-- Module exports, by position: [1] the frozenset class, [2] the set class.
return {
	frozenset,
	set,
}