// Pricing: request (input) tokens cost $10 per 1M tokens, so 250,000 tokens cost ~$2.50.
// Pricing: response (output) tokens cost $30 per 1M tokens, so 83,334 tokens cost ~$2.50.
// To limit each user's cost to $5 max per month,
// allow 250,000 request tokens and 83,334 response tokens.

import { doc, setDoc, getDoc } from 'firebase/firestore';
import { encodingForModel } from 'js-tiktoken';

// Tokenizer for the 'gpt-4-1106-preview' model; used to count request and
// response tokens.
const encoding = encodingForModel('gpt-4-1106-preview');

// Default monthly spend cap for a user, in dollars.
const DEFAULT_USER_LIMIT = 5; //$5
// Default number of trial requests granted to a user.
const DEFAULT_TRIAL_REQUESTS = 30;

// Dollar price of a single token:
// $10 per 1M request tokens and $30 per 1M response tokens.
const requestCostPerToken = 0.00001;
const responseCostPerToken = 0.00003;

// True only when the REACT_APP_ENDPOINT environment variable is exactly '/ask'.
const isProduction = process.env.REACT_APP_ENDPOINT === '/ask';

// Firestore collection holding the cost documents; non-production runs write
// to a separate collection.
export const DB_PATH = isProduction ? 'cost' : 'cost_development';

// Captured once, when this module is first evaluated. `year` and `month` are
// used as field keys in the cost documents.
// NOTE(review): `month` is the long month name as formatted by
// toLocaleString, so its text depends on the runtime's locale, and a
// long-lived session keeps the values from load time - confirm this is
// acceptable for keys shared across clients.
const date = new Date();
export const year = date.getFullYear();
export const month = date.toLocaleString('default', { month: 'long' });

/**
 * Rounds a number to at most five decimal places.
 *
 * @param {number} num - Value to round.
 * @returns {number} `num` rounded to the nearest 0.00001.
 */
const roundToFiveDec = (num) => {
  const scale = 1e5;
  const scaled = Math.round(num * scale);
  return scaled / scale;
};

/**
 * Converts a token count into a dollar amount.
 *
 * @param {string} type - 'request' selects the request price; any other
 *   value selects the response price.
 * @param {number} tokenCost - Number of tokens to price.
 * @returns {number} Dollar cost, rounded to five decimal places.
 */
const calculateDollarCost = (type, tokenCost) => {
  if (type === 'request') {
    return roundToFiveDec(tokenCost * requestCostPerToken);
  }
  return roundToFiveDec(tokenCost * responseCostPerToken);
};

/**
 * Decides whether a user has hit their usage limit once the current call is
 * taken into account.
 *
 * Paid users are limited by dollars: the cost of this call plus the month's
 * previously recorded request and response cost is compared against
 * `monthlyUserLimit`. Everyone else is limited by the number of trial
 * requests left after this one.
 *
 * @param {boolean} isPaidUser - Whether the dollar-based limit applies.
 * @param {number} tokenCostRequest - Request tokens used by this call.
 * @param {number} tokenCostResponse - Response tokens used by this call.
 * @param {number} monthRequestCost - Request dollars already recorded this month.
 * @param {number} monthResponseCost - Response dollars already recorded this month.
 * @param {number} monthlyUserLimit - Monthly cap in dollars.
 * @param {number} trialRequests - Trial requests available before this call.
 * @returns {boolean} True when the limit is reached.
 */
const isUserLimitReached = (
  isPaidUser,
  tokenCostRequest,
  tokenCostResponse,
  monthRequestCost,
  monthResponseCost,
  monthlyUserLimit,
  trialRequests
) => {
  // Trial users: limited by how many requests remain after this one.
  if (!isPaidUser) {
    const remainingAfterThisCall = trialRequests - 1;
    return remainingAfterThisCall <= 0;
  }

  // Paid users: limited by projected spend for the month.
  const projectedMonthSpend = roundToFiveDec(
    calculateDollarCost('request', tokenCostRequest) +
      calculateDollarCost('response', tokenCostResponse) +
      monthRequestCost +
      monthResponseCost
  );
  return projectedMonthSpend >= monthlyUserLimit;
};

/**
 * Resolves whether a user counts as a paid user.
 *
 * An explicitly stored flag (anything other than `undefined`) is returned
 * unchanged. Without a stored flag, the user is treated as paid only when
 * the part of the email after the first '@' is 'ryco.io'.
 *
 * @param {boolean|undefined} isPaidUser - Stored flag, if any.
 * @param {string} email - User's email address.
 * @returns {boolean} The stored flag, or the domain-based default.
 */
const determineUserType = (isPaidUser, email) => {
  if (isPaidUser !== undefined) {
    return isPaidUser;
  }

  const [, domain] = email.split('@');
  return domain === 'ryco.io';
};

/**
 * Returns a stored counter when it is a finite number, otherwise 0, so a
 * missing or corrupt field starts the running sum from zero.
 */
const storedNumberOrZero = (value) => (Number.isFinite(value) ? value : 0);

/**
 * Adds one call's usage to a stored `{ request, response, total }` bucket and
 * returns the updated bucket. A missing bucket is treated as all zeros.
 */
const addUsageToBucket = (bucket, tokenCostRequest, tokenCostResponse) => {
  const requestDollars = calculateDollarCost('request', tokenCostRequest);
  const responseDollars = calculateDollarCost('response', tokenCostResponse);

  const previousRequestCost = storedNumberOrZero(bucket?.request?.cost);
  const previousResponseCost = storedNumberOrZero(bucket?.response?.cost);

  const requestTokens =
    storedNumberOrZero(bucket?.request?.tokens) + tokenCostRequest;
  const responseTokens =
    storedNumberOrZero(bucket?.response?.tokens) + tokenCostResponse;

  return {
    request: {
      tokens: requestTokens,
      cost: roundToFiveDec(requestDollars + previousRequestCost),
    },
    response: {
      tokens: responseTokens,
      cost: roundToFiveDec(responseDollars + previousResponseCost),
    },
    total: {
      // The per-type counts above already include this call, so the total is
      // simply their sum; adding the call's tokens again would double-count.
      tokens: requestTokens + responseTokens,
      cost: roundToFiveDec(
        requestDollars +
          responseDollars +
          previousRequestCost +
          previousResponseCost
      ),
    },
  };
};

/**
 * Builds the `allTime`, `[year]` and `[year][month]` portion of a cost
 * document after adding one call's usage. `monthExtras` is merged into the
 * month bucket (used for the per-user `limitReached` flag).
 */
const buildUsageUpdate = (
  docData,
  tokenCostRequest,
  tokenCostResponse,
  monthExtras = {}
) => {
  // The year bucket is absent on the first call of a new year (and for a
  // brand-new document); optional chaining lets it start from zero.
  const storedYear = docData?.[year];

  return {
    allTime: addUsageToBucket(
      docData?.allTime,
      tokenCostRequest,
      tokenCostResponse
    ),
    [year]: {
      ...addUsageToBucket(storedYear, tokenCostRequest, tokenCostResponse),
      [month]: {
        ...monthExtras,
        ...addUsageToBucket(
          storedYear?.[month],
          tokenCostRequest,
          tokenCostResponse
        ),
      },
    },
  };
};

/**
 * Counts the tokens of one request/response pair and records the usage and
 * its dollar cost in Firestore, both in the user's document (keyed by email)
 * and in the shared 'overall' document.
 *
 * The two writes are started without being awaited; each one catches and
 * logs its own errors, so this function returns immediately with no value.
 *
 * Note: each write is a read followed by a merge write, not a transaction.
 *
 * @param {{request: string, response: string}} input - Texts to tokenize.
 * @param {object} db - Firestore instance passed through to `doc`.
 * @param {string} email - User's email; used as the user document id.
 * @returns {void}
 */
export const getTokenCost = (input, db, email) => {
  const tokenCostRequest = encoding.encode(input.request).length;
  const tokenCostResponse = encoding.encode(input.response).length;

  // Updates (or creates) the per-user cost document.
  const setUserTokenUsage = async () => {
    try {
      const docRef = doc(db, DB_PATH, email);
      const docSnap = await getDoc(docRef);
      let payload;

      if (docSnap.exists()) {
        const docData = docSnap.data();
        // Fall back to the default cap so a document without `userLimit`
        // is still limited (a comparison against undefined is always false).
        const monthlyUserLimit = docData.userLimit ?? DEFAULT_USER_LIMIT;
        const isPaidUser = determineUserType(docData.isPaidUser, email);
        // `??` rather than `||`: a stored 0 means the trial is used up and
        // must not be replaced by a fresh default allowance.
        const trialRequests = docData.trialRequests ?? DEFAULT_TRIAL_REQUESTS;
        const storedMonth = docData[year]?.[month];

        payload = {
          isPaidUser,
          // Never store a negative remaining-request count.
          trialRequests: Math.max(trialRequests - 1, 0),
          ...buildUsageUpdate(docData, tokenCostRequest, tokenCostResponse, {
            limitReached: isUserLimitReached(
              isPaidUser,
              tokenCostRequest,
              tokenCostResponse,
              storedNumberOrZero(storedMonth?.request?.cost),
              storedNumberOrZero(storedMonth?.response?.cost),
              monthlyUserLimit,
              trialRequests
            ),
          }),
        };
      } else {
        console.log('No such document!');
        // First call for this user: create the document with defaults.
        payload = {
          unlimited: false,
          userLimit: DEFAULT_USER_LIMIT,
          isPaidUser: determineUserType(undefined, email),
          trialRequests: DEFAULT_TRIAL_REQUESTS - 1,
          ...buildUsageUpdate(undefined, tokenCostRequest, tokenCostResponse, {
            limitReached: false,
          }),
        };
      }

      try {
        // merge: true keeps fields and month buckets not named in payload.
        await setDoc(docRef, payload, { merge: true });
      } catch (e) {
        console.error('Error adding document: ', e);
      }
    } catch (e) {
      console.error('Error getting document: ', e);
    }
  };

  // Updates (or creates) the aggregate 'overall' cost document.
  const setOverallTokenUsage = async () => {
    try {
      const docRef = doc(db, DB_PATH, 'overall');
      const docSnap = await getDoc(docRef);
      const docData = docSnap.exists() ? docSnap.data() : undefined;

      try {
        await setDoc(
          docRef,
          buildUsageUpdate(docData, tokenCostRequest, tokenCostResponse),
          { merge: true }
        );
      } catch (e) {
        console.error('Error adding document: ', e);
      }
    } catch (e) {
      // Log read failures instead of silently dropping them.
      console.error('Error getting document: ', e);
    }
  };

  // Intentionally not awaited: both helpers handle their own errors.
  setUserTokenUsage();
  setOverallTokenUsage();
};
