import WS from 'isomorphic-ws';
import { Subject } from 'rxjs';
import { dlog } from '../dev';

/**
 * Logic for websockets with a shared instance:
 * - Listeners registered with the same websocket URL share a single ws instance.
 * - Different ws instances can point at the same URL if necessary (NOT IMPLEMENTED YET: TODO).
 * - Each ws exposes one multiplexed subject that its subscribers listen to (each subscriber must filter messages for its own purposes).
 * - Each websocket has a timeout which is used to detect when messages have stopped coming through.
 * - When a websocket times out it:
 * 	- increments its reconnect-attempt counter, then
 * 		- calls its listeners' error handlers if the counter has reached the maximum number of attempts, or
 * 		- restarts the connection for all of its listeners otherwise.
 * - When a websocket receives a message it:
 * 	- resets its timeout
 * 	- forwards the message to all listeners
 *
 * - Listeners provide:
 * 	- a websocket URL
 * 	- a listener id (must be unique per websocket URL)
 * 	- a subscribe message (sent once the connection is open)
 * 	- an unsubscribe message (sent when the listener is removed while the connection is open)
 * 	- an error callback
 * 	- ? Handle Disconnect callback // TODO - assume error for now
 *
 * Each websocket's reconnect-attempt counter also decays over time, so that odd quiet periods on an otherwise healthy connection don't eventually add up to a failure.
 */

// No-message timeout: the ws server is expected to send at least a heartbeat every 5 seconds,
// so a silence of twice that long is treated as a stalled connection and triggers a reconnect.
const DEFAULT_RECONNECTION_DELAY_MS = 10 * 1000;
// Cap for the doubling no-message timeout. It equals the starting value above, so in practice
// the timeout does not grow between attempts.
const DEFAULT_MAX_RECONNECTION_DELAY_MS = 10 * 1000;
// Number of timeouts tolerated before the connection is declared broken and listeners' error handlers run.
const DEFAULT_MAX_RECONNECTION_ATTEMPTS = 3;
// No-message timeout used for the first message after a socket is created.
// NOTE(review): this is shorter than DEFAULT_RECONNECTION_DELAY_MS — confirm a 2s window is intended.
const FIRST_CONNECTION_SAFE_TIMEOUT_MS = 2 * 1000;
// Period of the interval that decrements the reconnect-attempt counter: three max-delay windows.
const DEFAULT_RECONNECT_COUNT_DECAY_MS = 3 * DEFAULT_MAX_RECONNECTION_DELAY_MS;

// Grace period between removing a URL's last listener and closing its socket, so that a listener
// which re-subscribes quickly can reuse the existing connection.
const WS_KILLING_DELAY = 3 * 1000;

// Flip to true locally to trace the connection lifecycle on the `websocket_debugging` channel.
const DEV_LOGGING = false;
/** Verbose lifecycle logging; forwards `message` to dlog only while DEV_LOGGING is enabled. */
const devLog = (message: any) => {
	if (!DEV_LOGGING) {
		return;
	}
	dlog(`websocket_debugging`, message);
};

/**
 * A parsed server message as forwarded to every subscriber of a shared socket.
 * Presumably `channel` is what subscribers filter on — confirm against the server protocol.
 */
export type WebSocketMessage = {
	channel: string;
	data: any;
};

/** What the registry stores for each listener attached to a socket. */
type ListenerState = {
	/** Sent over the socket once it is open. */
	subscribeMessage: string;
	/** Sent when the listener is removed while the socket is open. */
	unsubscribeMessage: string;
	/** Invoked once the socket has used up its reconnect attempts. */
	onError: (err?: any) => void;
};

/** Lifecycle of a shared socket as tracked in WebsocketState.connectionState. */
enum WEBSOCKET_CONNECTION_STATE {
	DISCONNECTED = 0,
	CONNECTED = 1,
	GOING_TO_CONNECT = 2,
	GOING_TO_DISCONNECT = 3,
}

/** Everything tracked for one websocket URL. */
type WebsocketState = {
	/** Current socket; replaced with a fresh instance on every reconnect. */
	ws: WS;
	/** Fan-out point for every parsed message received on `ws`. */
	subject: Subject<WebSocketMessage>;
	/** Registered listeners, keyed by listenerId. */
	listeners: Map<string, ListenerState>;
	connectionState: WEBSOCKET_CONNECTION_STATE;
	/** Incremented on each no-message timeout, decremented periodically by the decay interval. */
	reconnectAttempts: number;
	/** Once reconnectAttempts reaches this value, listeners are errored out. */
	maxReconnectionAttempts: number;
	/** Delay of the regular no-message timeout. */
	reconnectDelayMs: number;
	/** Pending no-message timer. */
	timeout: ReturnType<typeof setTimeout>;
	/** Interval that decays reconnectAttempts. */
	reconnectCountDecayInterval: ReturnType<typeof setInterval>;
};

/** Arguments for registering a listener: its ListenerState plus where and under which id it is registered. */
type ListenerProps = ListenerState & {
	wsUrl: string;
	/** Must be unique among the listeners of the same wsUrl. */
	listenerId: string;
};

/** Initial counters and limits copied into every newly created per-URL websocket state. */
const DEFAULT_WS_STATE: {
	reconnectAttempts: number;
	timeout: undefined;
	reconnectDelayMs: number;
	maxReconnectionAttempts: number;
} = {
	reconnectAttempts: 0,
	// No no-message timer is pending until the socket is wired up.
	timeout: undefined,
	reconnectDelayMs: DEFAULT_RECONNECTION_DELAY_MS,
	maxReconnectionAttempts: DEFAULT_MAX_RECONNECTION_ATTEMPTS,
};

/**
 * Module-wide (static) registry of shared websocket connections, keyed by websocket URL.
 *
 * All listeners registered for the same URL share one socket and one multiplexed `Subject`.
 * A "no message" timer drives recovery: whenever it fires, the URL's reconnect counter is
 * incremented and the socket is replaced, until `maxReconnectionAttempts` is reached. At that
 * point every listener is removed, the socket is torn down and each listener's `onError` runs.
 */
class WebsocketUtilClass {
	/** Shared connection state, one entry per websocket URL. */
	private static wsStateLookup = new Map<string, WebsocketState>();

	/** Stores the listener's messages and error callback in `listenerMap` under its listenerId. */
	private static addNewListenerToMap(
		props: ListenerProps,
		listenerMap: Map<string, ListenerState>
	) {
		listenerMap.set(props.listenerId, {
			onError: props.onError,
			subscribeMessage: props.subscribeMessage,
			unsubscribeMessage: props.unsubscribeMessage,
		});
	}

	/**
	 * Interval callback that lowers a URL's reconnect counter by one (never below zero), so that
	 * occasional quiet periods on an otherwise healthy connection don't accumulate into a failure.
	 */
	private static handleReconnectCountDecay(wsUrl: string) {
		devLog('decaying reconnectAttempt count');
		const wsState = this.wsStateLookup.get(wsUrl);
		// Defensive: no state means the socket has already been cleaned up.
		if (wsState && wsState.reconnectAttempts > 0) {
			wsState.reconnectAttempts--;
		}
	}

	/**
	 * Replaces every event handler of a socket that is being discarded with a no-op, so the socket
	 * can no longer touch the shared state for its URL (which may already belong to a newer socket).
	 * A no-op `onerror` stays attached on purpose: in Node, the `ws` package emits an `error` event
	 * when a still-connecting socket is closed or fails, and an EventEmitter `error` event with no
	 * listener is thrown.
	 */
	private static detachWsHandlers(ws: WS) {
		ws.onopen = () => {};
		ws.onmessage = () => {};
		ws.onclose = () => {};
		ws.onerror = () => {};
	}

	/**
	 * Tears down everything for `wsUrl`: detaches and closes the socket, stops its timers and
	 * forgets its state. No-op when no state exists for the URL.
	 */
	private static cleanupWs(wsUrl: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}
		// Detach before closing so the closing socket cannot report into state created later for this URL.
		this.detachWsHandlers(wsState.ws);
		wsState.ws.close();
		clearTimeout(wsState.timeout);
		clearInterval(wsState.reconnectCountDecayInterval);
		this.wsStateLookup.delete(wsUrl);
	}

	/**
	 * Creates the socket and state for `props.wsUrl`, with `props` as its first listener.
	 * @throws Error if state already exists for the URL.
	 */
	private static createNewWsState(props: ListenerProps) {
		devLog(`createNewWsState`);
		if (this.wsStateLookup.get(props.wsUrl)) {
			throw new Error('Tried to override existing ws state');
		}

		const listenerMap = new Map<string, ListenerState>();
		this.addNewListenerToMap(props, listenerMap);

		const newWs = new WS(props.wsUrl);

		const newWsState: WebsocketState = {
			ws: newWs,
			connectionState: WEBSOCKET_CONNECTION_STATE.GOING_TO_CONNECT,
			reconnectAttempts: DEFAULT_WS_STATE.reconnectAttempts,
			reconnectCountDecayInterval: setInterval(() => {
				this.handleReconnectCountDecay(props.wsUrl);
			}, DEFAULT_RECONNECT_COUNT_DECAY_MS),
			timeout: DEFAULT_WS_STATE.timeout,
			reconnectDelayMs: DEFAULT_WS_STATE.reconnectDelayMs,
			maxReconnectionAttempts: DEFAULT_WS_STATE.maxReconnectionAttempts,
			listeners: listenerMap,
			subject: new Subject<WebSocketMessage>(),
		};

		this.wsStateLookup.set(props.wsUrl, newWsState);
	}

	/**
	 * Registers the listener, either on the existing state for its URL or on a freshly created one.
	 * @returns `isNewWs: true` when a new socket/state was created.
	 * @throws Error if a listener with the same id is already registered for the URL.
	 */
	private static handleWebsocketStateForNewListener(props: ListenerProps): {
		isNewWs: boolean;
	} {
		const currentWsState = this.wsStateLookup.get(props.wsUrl);

		if (currentWsState) {
			if (currentWsState.listeners.get(props.listenerId)) {
				throw new Error(
					'Trying to subscribe two listeners with the same ID to the same websocket'
				);
			} else {
				this.addNewListenerToMap(props, currentWsState.listeners);
			}
			return { isNewWs: false };
		} else {
			this.createNewWsState(props);
			return { isNewWs: true };
		}
	}

	/**
	 * Sends the listener's subscribe message if the socket is open. A still-connecting socket is
	 * left alone, because its `onopen` subscribes every registered listener (this one included).
	 * A closing/closed socket is replaced.
	 */
	private static startListenerSubscription(wsUrl: string, listenerId: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}
		const ws = wsState.ws;

		if (ws.readyState === ws.OPEN) {
			const listenerState = wsState.listeners.get(listenerId);
			if (listenerState) {
				ws.send(listenerState.subscribeMessage);
			}
			return;
		}

		if (ws.readyState === ws.CONNECTING) {
			// Restarting here would throw away an in-flight connection once per listener added during connect.
			return;
		}

		dlog(
			`websocket_debugging`,
			`caught_new_ws_listener_but_connection_not_open`
		);
		this.refreshConnection(wsUrl);
	}

	/** Cancels the pending no-message timer for `wsUrl`, if any. */
	private static clearTimeout(wsUrl: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		if (wsState) {
			// Unqualified call: this is the global clearTimeout, not this static method.
			clearTimeout(wsState.timeout);
		}
	}

	/**
	 * Removes a listener and sends its unsubscribe message if the socket is open. When it was the
	 * last listener, the socket is closed after WS_KILLING_DELAY unless a listener was added meanwhile.
	 */
	private static removeListener(
		wsUrl: string,
		listenerId: string,
		context?: string
	) {
		dlog(
			`websocket_debugging`,
			`removing_ws_listener ${
				context ? `.. context(${context}) ..` : ''
			} ${listenerId}`
		);

		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}

		const listener = wsState.listeners.get(listenerId);
		wsState.listeners.delete(listenerId);

		if (wsState.ws.readyState === wsState.ws.OPEN && listener) {
			wsState.ws.send(listener.unsubscribeMessage);
		}

		if (wsState.listeners.size > 0) {
			return;
		}

		setTimeout(() => {
			if (this.wsStateLookup.get(wsUrl) !== wsState) {
				// The state this timer was scheduled for is gone (and may have been replaced by a fresh
				// one for the same URL). cleanupWs works by URL, so calling it could kill a newer,
				// still-used connection.
				dlog(
					`websocket_debugging`,
					`skipping_ws_kill ${wsUrl} .. ws instance was already cleaned up or replaced`
				);
				return;
			}
			if (wsState.listeners.size === 0) {
				dlog(
					`websocket_debugging`,
					`killing_ws_instance ${wsUrl} after removing last listener`
				);
				this.cleanupWs(wsUrl);
			} else {
				dlog(
					`websocket_debugging`,
					`NOT_killing_ws_instance ${wsUrl} .. would have killed ws because we removed the last listener, but a new listener was added in the meantime`
				);
			}
		}, WS_KILLING_DELAY);
	}

	/**
	 * Gives up on `wsUrl`. Removes every listener (sending their unsubscribe messages), tears the
	 * socket down immediately, and only then invokes each listener's `onError`. Doing it in this order
	 * means a listener that re-subscribes from inside `onError` gets a fresh connection instead of
	 * the dying one.
	 */
	private static handleExceededMaxReconnectAttempts(wsUrl: string) {
		devLog('handleExceededMaxReconnectAttempts');
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}

		// Snapshot first: removeListener mutates wsState.listeners.
		const failedListeners = Array.from(wsState.listeners.entries());
		for (const [listenerId] of failedListeners) {
			this.removeListener(
				wsUrl,
				listenerId,
				'handleExceededMaxReconnectAttempts'
			);
		}

		this.cleanupWs(wsUrl);

		for (const [, listener] of failedListeners) {
			listener.onError();
		}
	}

	/**
	 * Replaces the socket for `wsUrl` with a new one and wires it to the existing listeners.
	 * The old socket is detached first: otherwise its late `onclose` would mark the replacement as
	 * DISCONNECTED, and any stray message from it would reset the replacement's timeout.
	 */
	private static refreshConnection(wsUrl: string) {
		devLog('refreshConnection');
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}
		const oldWs = wsState.ws;

		this.detachWsHandlers(oldWs);

		// Close the previous connection
		if (oldWs.readyState === oldWs.OPEN) {
			oldWs.close();
		} else {
			devLog('handling refresh connection before onopen event fired');
			// Not open yet: close it as soon as it opens (if it ever does).
			oldWs.onopen = () => {
				oldWs.close();
			};
		}

		wsState.ws = new WS(wsUrl);
		wsState.connectionState = WEBSOCKET_CONNECTION_STATE.GOING_TO_CONNECT;

		// Link the new WS to the existing listeners
		this.handleNewWs(wsUrl);
	}

	/**
	 * Fired when no message arrived within the current timeout. Counts a failed attempt, then either
	 * reconnects or, once the limit is reached, errors out all listeners.
	 */
	private static handleNoMessageTimeout(wsUrl: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}

		devLog('handleMessageTimeout');

		// Only an intentional shutdown suppresses timeout handling. An unexpected close (DISCONNECTED)
		// is exactly the "messages stopped" case this timeout exists for, so it is retried/reported.
		if (
			wsState.connectionState === WEBSOCKET_CONNECTION_STATE.GOING_TO_DISCONNECT
		) {
			devLog('skipping timeout handling because already disconnecting');
			return;
		}

		this.clearTimeout(wsUrl);

		wsState.reconnectAttempts++;
		if (wsState.reconnectAttempts >= wsState.maxReconnectionAttempts) {
			wsState.connectionState = WEBSOCKET_CONNECTION_STATE.GOING_TO_DISCONNECT;
			this.handleExceededMaxReconnectAttempts(wsUrl);
		} else {
			wsState.reconnectDelayMs = Math.min(
				wsState.reconnectDelayMs * 2,
				DEFAULT_MAX_RECONNECTION_DELAY_MS
			);
			this.refreshConnection(wsUrl);
			// refreshConnection already armed the first-message timeout via handleNewWs;
			// this call replaces it with the regular reconnectDelayMs timeout.
			this.restartNoMessageTimeout(wsUrl);
		}
	}

	/**
	 * (Re)arms the no-message timer for `wsUrl`. Uses FIRST_CONNECTION_SAFE_TIMEOUT_MS when
	 * `firstTimeout` is set, otherwise the state's reconnectDelayMs. Does nothing while disconnecting.
	 */
	private static restartNoMessageTimeout(
		wsUrl: string,
		firstTimeout?: boolean
	) {
		this.clearTimeout(wsUrl);
		const wsState = this.wsStateLookup.get(wsUrl);
		if (!wsState) {
			return;
		}

		if (
			wsState.connectionState === WEBSOCKET_CONNECTION_STATE.GOING_TO_DISCONNECT
		) {
			devLog('skipping restartNoMessageTimeout because already disconnecting');
			return;
		}

		devLog('restartNoMessageTimeout');

		const timeoutDelay = firstTimeout
			? FIRST_CONNECTION_SAFE_TIMEOUT_MS
			: wsState.reconnectDelayMs;
		wsState.timeout = setTimeout(() => {
			this.handleNoMessageTimeout(wsUrl);
		}, timeoutDelay);
	}

	/**
	 * Send subscription message for all relevant listeners
	 * @param wsUrl
	 */
	private static handleWsConnected(wsUrl: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		if (wsState) {
			for (const [listenerId, _listenerState] of wsState.listeners) {
				this.startListenerSubscription(wsUrl, listenerId);
			}
		}
	}

	/** Attaches the event handlers for the current socket of `wsUrl` and arms its first-message timeout. */
	private static handleNewWs(wsUrl: string) {
		const wsState = this.wsStateLookup.get(wsUrl);
		const ws = wsState.ws;

		ws.onopen = (_event) => {
			devLog('onopen');
			wsState.connectionState = WEBSOCKET_CONNECTION_STATE.CONNECTED;
			this.handleWsConnected(wsUrl);
		};

		ws.onmessage = (incoming) => {
			devLog('onmessage');
			this.restartNoMessageTimeout(wsUrl);

			let message: WebSocketMessage;
			try {
				message = JSON.parse(incoming.data as string) as WebSocketMessage;
			} catch (err) {
				// A malformed frame must not throw out of the socket's event handler; drop it instead.
				dlog(`websocket_debugging`, `dropping_unparseable_ws_message ${err}`);
				return;
			}

			// Forward message to all observers
			wsState.subject.next(message);
		};

		ws.onclose = (event) => {
			dlog(`websocket_debugging`, `onclose ${event.code} ${event.reason}`);
			wsState.connectionState = WEBSOCKET_CONNECTION_STATE.DISCONNECTED;
			devLog('onclose');
		};

		ws.onerror = () => {
			// Recovery is driven by the no-message timeout. Having a listener at all keeps Node's
			// `ws` package from throwing an unhandled 'error' event.
			devLog('onerror');
		};

		this.restartNoMessageTimeout(wsUrl, true);
	}

	/**
	 * Registers a listener on the shared socket for `props.wsUrl`, creating the socket if needed.
	 * @returns the shared message subject and an `unsubscribe` callback that removes the listener.
	 * @throws Error if a listener with the same id is already registered for the URL.
	 */
	public static createWebsocketListener(props: ListenerProps): {
		unsubscribe: () => void;
		subject: Subject<WebSocketMessage>;
	} {
		dlog(`websocket_debugging`, `creating_ws_listener ${props.listenerId}`);

		const { isNewWs } = this.handleWebsocketStateForNewListener(props);

		if (isNewWs) {
			this.handleNewWs(props.wsUrl);
		}

		const wsState = this.wsStateLookup.get(props.wsUrl);
		const subject = wsState.subject;

		if (!isNewWs) {
			// Only fire this immediately if it's an existing ws. If it's a new WS then the listener will start once the ws has finished connecting.
			this.startListenerSubscription(props.wsUrl, props.listenerId);
		}

		return {
			unsubscribe: () => {
				devLog(`unsubscribe callback`);
				this.removeListener(
					props.wsUrl,
					props.listenerId,
					'unsubscribe callback'
				);
			},
			subject,
		};
	}
}

// The shared-websocket registry is this module's default export.
export { WebsocketUtilClass as default };
