package com.openexchange.office.rt2.ws;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import com.openexchange.exception.OXException;
import com.openexchange.office.rt2.exception.RT2SessionInvalidException;
import com.openexchange.office.rt2.exception.RT2TypedException;
import com.openexchange.office.rt2.protocol.value.RT2SessionIdType;
import com.openexchange.office.session.SessionService;
import com.openexchange.office.tools.error.ErrorCode;
import com.openexchange.session.Session;

/**
 * Enforces an upper limit on the number of RT2 channels a single user may have
 * registered at the same time.
 *
 * <p>Channels are grouped by the user id of the session they were registered with.
 * Every modification of a user's channel set happens inside
 * {@link ConcurrentHashMap#compute} / {@link ConcurrentHashMap#computeIfPresent},
 * so the limit check, the insertion and the removal of an emptied entry are atomic
 * per user, and no empty per-user entry is ever left in the map.
 */
public class RT2SessionCountValidator {

	/** Resolves RT2 session ids to sessions; used to determine the owning user id. */
	private final SessionService sessionService;

	/** Maximum number of distinct channel ids that may be registered per user. */
	private final int maxSessionsPerUser;

	/** User id -> channel ids currently registered for that user (never an empty set). */
	private final ConcurrentHashMap<Integer, Set<RT2ChannelId>> countSessionsOfUser = new ConcurrentHashMap<>();

	/**
	 * Creates a validator.
	 *
	 * @param sessionService service used to look up the session behind an RT2 session id
	 * @param maxSessionsPerUser maximum number of channels a single user may have registered
	 */
	public RT2SessionCountValidator(SessionService sessionService, int maxSessionsPerUser) {
		this.sessionService = sessionService;
		this.maxSessionsPerUser = maxSessionsPerUser;
	}

	/**
	 * Registers a channel for the user that owns the given session.
	 *
	 * <p>Registering a channel id that is already registered for that user changes
	 * nothing and is never rejected.
	 *
	 * @param sessionId RT2 session id used to determine the user
	 * @param channelId channel to register
	 * @throws RT2SessionInvalidException if no session exists for {@code sessionId}
	 * @throws RT2TypedException with {@link ErrorCode#TOO_MANY_CONNECTIONS_ERROR} if registering
	 *         the channel would exceed {@code maxSessionsPerUser}; the channel is not registered then
	 * @throws OXException if the session lookup fails
	 */
	public void addSession(RT2SessionIdType sessionId, RT2ChannelId channelId) throws RT2TypedException, OXException {
		final Session session = sessionService.getSession4Id(sessionId.getValue());
		if (session == null) {
			throw new RT2SessionInvalidException(new ArrayList<>());
		}

		// The remapping function cannot throw the checked exception, so it only records the rejection.
		final boolean[] limitExceeded = { false };
		// compute() is atomic per key: the limit check and the insertion cannot interleave
		// with a concurrent removeSession() that drops this user's entry.
		countSessionsOfUser.compute(session.getUserId(), (userId, channelIds) -> {
			Set<RT2ChannelId> ids = channelIds;
			if (ids == null) {
				ids = ConcurrentHashMap.newKeySet();
			}
			if (!ids.contains(channelId) && (ids.size() >= maxSessionsPerUser)) {
				limitExceeded[0] = true;
				// Leave the mapping unchanged; if the user had no entry yet this returns null,
				// so no empty set is stored.
				return channelIds;
			}
			ids.add(channelId);
			return ids;
		});

		if (limitExceeded[0]) {
			throw new RT2TypedException(ErrorCode.TOO_MANY_CONNECTIONS_ERROR, new ArrayList<>());
		}
	}

	/**
	 * Unregisters a channel from every user it is registered for and drops users that
	 * no longer have any registered channel.
	 *
	 * @param channelId channel to unregister; ids that are not registered are ignored
	 */
	public void removeSession(RT2ChannelId channelId) {
		// The owning user is not known here, so every user entry is checked. The key-set
		// iterator of ConcurrentHashMap is weakly consistent and tolerates concurrent removals.
		for (Integer userId : countSessionsOfUser.keySet()) {
			countSessionsOfUser.computeIfPresent(userId, (key, channelIds) -> {
				channelIds.remove(channelId);
				// Returning null removes the entry atomically together with the emptiness check.
				return channelIds.isEmpty() ? null : channelIds;
			});
		}
	}

	/**
	 * Returns the number of registered channels per user.
	 *
	 * <p>The result is built from a weakly consistent traversal, so it may or may not
	 * reflect updates that happen concurrently.
	 *
	 * @return a new, independent map from user id to the number of registered channels
	 */
	public Map<Integer, Integer> getCountSessionsOfUsers() {
		final Map<Integer, Integer> res = new HashMap<>();
		countSessionsOfUser.forEach((userId, channelIds) -> res.put(userId, channelIds.size()));
		return res;
	}
}
