package com.openexchange.office.rt2.core.ws;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import com.openexchange.exception.OXException;
import com.openexchange.office.rt2.core.config.RT2Const;
import com.openexchange.office.rt2.core.exception.RT2SessionInvalidException;
import com.openexchange.office.rt2.core.exception.RT2TypedException;
import com.openexchange.office.rt2.protocol.value.RT2SessionIdType;
import com.openexchange.office.tools.common.error.ErrorCode;
import com.openexchange.office.tools.service.session.SessionService;
import com.openexchange.session.Session;

/**
 * Tracks the RT2 channels registered by each user and enforces an upper limit on
 * the number of channels a single user may have registered at the same time.
 * <p>
 * Channels are grouped by the user id of the session that registered them. The
 * limit is read from the configuration property {@value #KEY_MAX_SESSIONS_PER_NODE}
 * and falls back to {@code RT2Const.DEFAULT_MAX_SESSIONS_PER_NODE}.
 * <p>
 * Thread-safety: every mutation of a user's entry runs inside
 * {@link ConcurrentHashMap#compute} or {@link ConcurrentHashMap#computeIfPresent}.
 * That makes "add, check limit, roll back" and "remove, drop empty entry" atomic
 * per user, so they cannot interleave. Previously an entry could be removed
 * while another thread was about to add to it.
 */
@Service
public class RT2SessionCountValidator {

	/**
	 * Configuration key holding the maximum number of open documents (channels) per user.
	 * NOTE: despite the constant's name, the limit is applied per user, not per node.
	 */
	public static final String KEY_MAX_SESSIONS_PER_NODE = "com.openexchange.office.maxOpenDocumentsPerUser";

	@Autowired
	private SessionService sessionService;

	/** Maximum number of channels a single user may have registered at once. */
	@Value("${" + KEY_MAX_SESSIONS_PER_NODE + ":" + RT2Const.DEFAULT_MAX_SESSIONS_PER_NODE + "}")
	private Integer maxSessionsPerUser;

	/** Registered channel ids keyed by user id; entries are removed as soon as they become empty. */
	private final ConcurrentHashMap<Integer, Set<RT2ChannelId>> countSessionsOfUser = new ConcurrentHashMap<>();

	/** Creates a validator whose limit is injected via {@code @Value}. */
	public RT2SessionCountValidator() {}

	/**
	 * Creates a validator with an explicit per-user limit.
	 *
	 * @param maxSessionsPerUser maximum number of channels a single user may register
	 */
	public RT2SessionCountValidator(int maxSessionsPerUser) {
		this.maxSessionsPerUser = maxSessionsPerUser;
	}

	/**
	 * Registers {@code channelId} for the user owning {@code sessionId}.
	 * <p>
	 * If registering the channel would push the user above the configured limit,
	 * the registration is rolled back and an exception is thrown. Registering a
	 * channel that is already registered for the user has no effect.
	 *
	 * @param sessionId id of the session whose user owns the channel
	 * @param channelId channel to register
	 * @throws RT2SessionInvalidException if no session exists for {@code sessionId}
	 * @throws RT2TypedException with {@code ErrorCode.TOO_MANY_CONNECTIONS_ERROR} if the limit would be exceeded
	 * @throws OXException if the session lookup fails
	 */
	public void addSession(RT2SessionIdType sessionId, RT2ChannelId channelId) throws RT2TypedException, OXException {
		final Session session = sessionService.getSession4Id(sessionId.getValue());
		if (session == null) {
			throw new RT2SessionInvalidException(new ArrayList<>(), RT2SessionInvalidException.SESSION_ID_ADD_INFO, sessionId.getValue());
		}

		// Unbox once, before any mutation, so a missing limit cannot leave a half-done registration.
		final int limit = maxSessionsPerUser;
		// Checked exceptions cannot escape the remapping function, so record the outcome and throw afterwards.
		final boolean[] rejected = new boolean[1];

		countSessionsOfUser.compute(session.getUserId(), (userId, channelIds) -> {
			Set<RT2ChannelId> ids = channelIds;
			if (ids == null) {
				// Synchronized so getCountSessionsOfUsers() can safely read size() outside compute().
				ids = Collections.synchronizedSet(new HashSet<>());
			}
			if (ids.add(channelId) && (ids.size() > limit)) {
				ids.remove(channelId);
				rejected[0] = true;
			}
			// Returning null removes the mapping, so an empty set is never left behind.
			return ids.isEmpty() ? null : ids;
		});

		if (rejected[0]) {
			throw new RT2TypedException(ErrorCode.TOO_MANY_CONNECTIONS_ERROR, new ArrayList<>());
		}
	}

	/**
	 * Unregisters {@code channelId} from every user it is registered for and drops
	 * users that no longer have any registered channel.
	 *
	 * @param channelId channel to unregister
	 */
	public void removeSession(RT2ChannelId channelId) {
		// ConcurrentHashMap key-set iteration is weakly consistent and tolerates concurrent removal.
		for (Integer userId : countSessionsOfUser.keySet()) {
			countSessionsOfUser.computeIfPresent(userId, (key, channelIds) -> {
				channelIds.remove(channelId);
				return channelIds.isEmpty() ? null : channelIds;
			});
		}
	}

	/**
	 * Returns a snapshot of the number of registered channels per user id.
	 *
	 * @return a new, mutable map from user id to channel count; later changes are not reflected
	 */
	public Map<Integer, Integer> getCountSessionsOfUsers() {
		final Map<Integer, Integer> res = new HashMap<>();
		for (Map.Entry<Integer, Set<RT2ChannelId>> entry : countSessionsOfUser.entrySet()) {
			res.put(entry.getKey(), entry.getValue().size());
		}
		return res;
	}
}
