import { getAllDefinedValues } from "../Constants/systemSetting";

/**
 * Fits an ordinary-least-squares line y = slope * x + yIntercept to two fields
 * of `scatterPlotData` and reports the goodness of fit and Pearson correlation.
 *
 * Values for each axis are collected with `getAllDefinedValues` and paired by
 * index. Only the first `datapoints` pairs are used, where `datapoints` is the
 * length of the shorter list, so every x value has a matching y value.
 * NOTE(review): if `getAllDefinedValues` drops missing values per field
 * independently, a record lacking only one field shifts the pairing of later
 * records — confirm against that helper.
 *
 * Every value is parsed with `parseFloat` once, up front, so the means and all
 * sums see the same numbers.
 *
 * @param {*} scatterPlotData - Data handed to `getAllDefinedValues`
 *   (presumably an array of records keyed by `xAxis` / `yAxis`).
 * @param {string} [xAxis="measure"] - Field used as the independent variable.
 * @param {string} [yAxis="compensation"] - Field used as the dependent variable.
 * @returns {{
 *   slope: number,
 *   yIntercept: number,
 *   RSquared: number,
 *   linePlotData: Array<{measure: number, correlation: number}>,
 *   correlation: number,
 *   datapoints: number,
 * }}
 *   - `linePlotData` holds the two end points of the fitted line. The x range
 *     is stretched by 10% (relative to each end's own magnitude) beyond the
 *     smallest and largest x.
 *   - Degenerate input (no pairs, or all x values equal) yields NaN/Infinity
 *     statistics rather than throwing.
 */
export function getLinearCorrelationData(
  scatterPlotData,
  xAxis = "measure",
  yAxis = "compensation"
) {
  // Fraction by which the plotted line reaches past the outermost points.
  const extensionFactor = 0.1;

  const getMean = (values) =>
    values.reduce((acc, cur) => acc + cur, 0) / values.length;

  // Moves a "max" end upward and a "min" end downward by a relative amount,
  // whatever the sign of the value. 0 (and NaN) are returned unchanged instead
  // of being coerced to 0, so a bad input stays visible as NaN.
  const extendValue = (value, type) => {
    if ((value > 0 && type === "max") || (value < 0 && type === "min")) {
      return value * (1 + extensionFactor);
    }
    if ((value < 0 && type === "max") || (value > 0 && type === "min")) {
      return value / (1 + extensionFactor);
    }
    return value;
  };

  const rawXValues = getAllDefinedValues(scatterPlotData, xAxis);
  const rawYValues = getAllDefinedValues(scatterPlotData, yAxis);
  const datapoints = Math.min(rawXValues.length, rawYValues.length);

  // Truncate to matched pairs so yValues[i] always exists for xValues[i].
  const toNumber = (value) => parseFloat(value);
  const xValues = rawXValues.slice(0, datapoints).map(toNumber);
  const yValues = rawYValues.slice(0, datapoints).map(toNumber);

  const xMean = getMean(xValues);
  const yMean = getMean(yValues);

  // Centered sums shared by slope, R² and the correlation coefficient:
  // sxy = Σ(x - x̄)(y - ȳ), sxx = Σ(x - x̄)², syy = Σ(y - ȳ)² (the TSS).
  let sxy = 0;
  let sxx = 0;
  let syy = 0;
  for (let i = 0; i < datapoints; i += 1) {
    const dx = xValues[i] - xMean;
    const dy = yValues[i] - yMean;
    sxy += dx * dy;
    sxx += dx * dx;
    syy += dy * dy;
  }

  const slope = sxy / sxx;
  const yIntercept = yMean - xMean * slope;
  const getYValue = (xValue) => slope * xValue + yIntercept;

  // A reduce (instead of Math.max(...xs)) avoids argument-count limits on
  // large arrays, while still yielding -Infinity/Infinity for empty input
  // and propagating NaN.
  const maxXValue = xValues.reduce((acc, x) => Math.max(acc, x), -Infinity);
  const minXValue = xValues.reduce((acc, x) => Math.min(acc, x), Infinity);
  const maxXValuePlusExtra = extendValue(maxXValue, "max");
  const minXValueMinusExtra = extendValue(minXValue, "min");
  const linePlotData = [
    {
      measure: minXValueMinusExtra,
      correlation: getYValue(minXValueMinusExtra),
    },
    {
      measure: maxXValuePlusExtra,
      correlation: getYValue(maxXValuePlusExtra),
    },
  ];

  // R² = 1 - RSS / TSS, where RSS is the residual sum of squares of the fit.
  const RSS = xValues.reduce((acc, xVal, i) => {
    const residual = getYValue(xVal) - yValues[i];
    return acc + residual * residual;
  }, 0);
  const RSquared = 1 - RSS / syy;

  // Pearson correlation coefficient (spreadsheet CORREL).
  const correlation = sxy / Math.sqrt(sxx * syy);

  return { slope, yIntercept, RSquared, linePlotData, correlation, datapoints };
}
