#!/usr/local/pypy3-6.0.0/bin/pypy3

"""
Gradient descent in pure python.

Code originated at https://realpython.com/numpy-tensorflow-performance/
"""
import json
import time
import typing

import itertools as it


def make_used(variable: object) -> None:
    """Mark ``variable`` as used so linters do not flag it as unused.

    This is a deliberate no-op: the argument is never inspected or
    evaluated, and the function always returns ``None``.
    """
    # Dropping the local reference counts as a use of the parameter
    # without touching the object itself.
    del variable


def py_descent(
        x_list: typing.List[float],
        d_list: typing.List[float],
        scaling_factor_mu: float,
        n_epochs: int,
) -> typing.List[float]:
    """Fit ``d ~ w0 + w1 * x`` by batch gradient descent on the mean squared error.

    Args:
        x_list: Input values, one per sample.
        d_list: Desired outputs; must have the same length as ``x_list``.
        scaling_factor_mu: Step size applied to the descent direction each epoch.
        n_epochs: Number of full passes over the data.

    Returns:
        ``[w0, w1]`` (intercept, slope) after ``n_epochs`` updates.  Both
        weights start at 0.0, so zero epochs returns ``[0.0, 0.0]``.

    Raises:
        ValueError: If the inputs differ in length or are empty.
    """
    sample_count = len(x_list)
    # zip() would silently truncate to the shorter input while the scale
    # factor below still used len(x_list), giving a quietly wrong answer.
    if sample_count != len(d_list):
        raise ValueError(
            'x_list and d_list must have the same length (got {} and {})'.format(
                sample_count, len(d_list)))
    if sample_count == 0:
        raise ValueError('x_list and d_list must not be empty')

    two_over_n = 2.0 / sample_count

    intercept = 0.0
    slope = 0.0
    # Initial predictions are all zero.  This is a generator (rather than a
    # list) so the name keeps a single type when it is rebound in the loop.
    predictions = (0.0 for _ in x_list)

    for _ in it.repeat(None, n_epochs):
        # The errors are iterated twice below, so they must be materialized;
        # this also consumes the lazy predictions before the weights change.
        errors = tuple(d - y for d, y in zip(d_list, predictions))
        # 2/N * sum(err) and 2/N * sum(err * x) are the *negative* gradient
        # of the mean squared error, hence the additions below descend.
        intercept_descent = two_over_n * sum(errors)
        slope_descent = two_over_n * sum(e * x for e, x in zip(errors, x_list))
        intercept += scaling_factor_mu * intercept_descent
        slope += scaling_factor_mu * slope_descent
        # Evaluated lazily at the top of the next epoch, i.e. with the
        # weights that were just computed.
        predictions = (intercept + slope * x for x in x_list)

    return [intercept, slope]


def main() -> None:
    """Load sample data from ``x-d.json``, run gradient descent, and report.

    Prints the fitted weights followed by the time the fit took.
    """
    # The JSON file holds plain Python lists under the keys read below.  It
    # replaces the numpy arrays of the source material, which were converted
    # with ``.tolist()`` (after ``.squeeze()`` for the desired outputs, to
    # obtain a 1-d list).
    with open('x-d.json', 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    x_list = data['evenly_spaced_values_x']
    d_list = data['desired_outputs_d']

    # `mu` is a step size, or scaling factor.
    scaling_factor_mu = 0.001
    n_epochs = 10000

    # perf_counter() is monotonic, so the measured interval cannot be
    # distorted by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    py_w = py_descent(x_list, d_list, scaling_factor_mu, n_epochs)
    elapsed_time = time.perf_counter() - start

    print('This should be close to 3 and close to 2:')
    print(py_w)
    # Example output from one run:
    # [2.959859852416156, 2.0329649630002757]

    # The format spec already rounds to two decimals.
    print('Solve time: {:.2f} seconds'.format(elapsed_time))


# Script entry point.  NOTE: there is no ``__name__ == '__main__'`` guard, so
# this call also runs whenever the module is imported.
main()