import { useCallback, useEffect, useRef, useState } from "react";

import { IUseWebSocket } from "./types/websocket";
import { useLatestJwtToken } from "../auth/auth-store-context";
import { PlainError } from "../components/events/plain-error";
import { useStore } from "../models/helpers";
import { RootStore } from "../models/root";

/**
 * Opens an authenticated WebSocket connection and exposes a JSON sender.
 *
 * A socket is created once a JWT is available. The token is sent as the
 * `token` query parameter, followed by any extra `query` entries. While a
 * socket exists it is reused. A new one is only created when this hook's
 * connect effect re-runs after the previous socket has closed. The socket is
 * closed when the component unmounts.
 *
 * Every incoming message is first checked for a server error envelope (shown
 * as a toast), then forwarded to `onResponse` as JSON-parsed data.
 *
 * @param url - Base WebSocket URL; the query string is appended to it.
 * @param query - Extra query parameters appended after `token`.
 * @param onError - Called when the socket emits an `error` event.
 * @param onResponse - Called with the JSON-parsed `event.data` of each message.
 *   NOTE(review): the parameter is typed `MessageEvent<U>`, but the parsed
 *   payload (not the event) is what gets passed. The type is kept as-is for
 *   caller compatibility.
 * @returns `sendMessage` (JSON-serialises and sends a request), the socket
 *   captured at render time, and whether the socket is currently open.
 */
export const useWebSocket = <T extends object, U extends object>({
  url,
  query,
  onError,
  onResponse,
}: {
  url: string;
  query?: Record<string, string>;
  onError?: () => Promise<void> | void;
  onResponse?: (event?: MessageEvent<U>) => Promise<void> | void;
}): IUseWebSocket<T> => {
  const [isReady, setIsReady] = useState(false);

  const socketRef = useRef<WebSocket | null>(null);
  const token = useLatestJwtToken();
  const store = useStore();

  // Handlers are attached once per socket. Read the callbacks through refs so
  // the latest ones are invoked, not those from the render that connected.
  const onErrorRef = useRef(onError);
  const onResponseRef = useRef(onResponse);
  useEffect(() => {
    onErrorRef.current = onError;
    onResponseRef.current = onResponse;
  }, [onError, onResponse]);

  useEffect(() => {
    // Reuse an existing socket; without a token we cannot authenticate yet.
    if (socketRef.current || !token) {
      return;
    }

    const queryParams = new URLSearchParams({ token });
    Object.entries(query || {}).forEach(([key, value]) => {
      queryParams.append(key, value);
    });

    const fullUrl = `${url}?${queryParams.toString()}`;
    const socket = new WebSocket(fullUrl);

    socket.onmessage = (event) => {
      let payload: unknown;
      try {
        payload = JSON.parse(event.data);
      } catch (error) {
        // A non-JSON frame would otherwise throw out of this event handler.
        console.error("WebSocket received a non-JSON message", error);
        return;
      }
      extractResponseError(event, store);
      // See the NOTE on `onResponse`: the parsed payload is passed, not the event.
      void onResponseRef.current?.(payload as MessageEvent<U>);
    };

    socket.onclose = () => {
      // Clearing the ref lets a later run of this effect open a new socket.
      socketRef.current = null;
      setIsReady(false);
    };

    socket.onerror = (event) => {
      setIsReady(false);
      console.log("WebSocket error", event);
      void onErrorRef.current?.();
    };

    socket.onopen = () => {
      setIsReady(true);
    };

    socketRef.current = socket;
    // The callback identities stay in the dependency list, as in the original
    // hook. Together with the guard above, a change re-runs this effect, and
    // that only reconnects if the previous socket has already closed.
  }, [onError, onResponse, query, store, token, url]);

  // Close the socket on unmount only. A cleanup returned from the connect
  // effect would also run whenever `query` or a callback identity changes.
  useEffect(
    () => () => {
      const socket = socketRef.current;
      socketRef.current = null;
      if (socket) {
        // Detach first so the close does not trigger state updates or `onError`.
        socket.onopen = null;
        socket.onmessage = null;
        socket.onerror = null;
        socket.onclose = null;
        socket.close();
      }
    },
    [],
  );

  const sendMessage = useCallback((request: T) => {
    // WebSocket.send throws while the socket is still CONNECTING, so callers
    // should wait for `isReady` before sending.
    socketRef.current?.send(JSON.stringify(request));
  }, []);

  return { sendMessage, webSocket: socketRef.current, isReady };
};

/**
 * Shows a toast for server-sent error envelopes.
 *
 * Parses `event.data` as JSON. When the result has `kind === "error"` and a
 * truthy `error` field, it pushes a `PlainError` (with `tx` set to that field)
 * onto the root store.
 *
 * Data that is not valid JSON, or whose JSON is `null` or a primitive, is
 * ignored here. This function only detects error envelopes; it does not
 * validate other messages.
 *
 * @param event - The raw WebSocket message event; `event.data` is parsed as JSON.
 * @param store - Root store that receives the toast event.
 */
const extractResponseError = (event: MessageEvent, store: RootStore): void => {
  let data;
  try {
    data = JSON.parse(event.data);
  } catch {
    // Not JSON, so it cannot be an error envelope.
    return;
  }
  // Optional chaining guards a JSON `null` payload, where `data.kind` would throw.
  if (data?.kind === "error" && data.error) {
    store.addToastEvent(new PlainError({ tx: data.error }));
  }
};
