"""
Based on
1. https://ax.dev/tutorials/tune_cnn.html
2. https://ax.dev/docs/api.html
3. https://ax.dev/tutorials/gpei_hartmann_service.html
4. https://ax.dev/tutorials/gpei_hartmann_loop.html
"""
from random import random
from ax.service.ax_client import AxClient


def f_to_minimize(w1, w2, w3, noise_scale=0.1):
    """Return a noisy objective value for the weights ``w1``, ``w2``, ``w3``.

    A fourth weight is implied as ``w4 = 1 - (w1 + w2 + w3)``, so the four
    weights sum to one. The deterministic part is the sum of two
    Rosenbrock-shaped terms, one over (w1, w2) and one over (w3, w4). Uniform
    noise in ``[-noise_scale / 2, noise_scale / 2)`` is added on top to
    simulate a noisy evaluation.

    Args:
        w1, w2, w3: The free weights.
        noise_scale: Width of the uniform noise band. The default of 0.1
            reproduces the original behavior. Pass 0 for a deterministic value.

    Returns:
        The (noisy) objective value as a float.
    """
    w4 = 1 - (w1 + w2 + w3)
    # Keep the same left-to-right summation order as before so results are
    # bit-for-bit identical with the default noise_scale.
    deterministic = (
        (0.5 - w1) ** 2
        + 1.2 * (w2 - w1**2) ** 2
        + (1.5 - w3) ** 2
        + 3.2 * (w4 - w3**2) ** 2
    )
    # random() is uniform on [0, 1), so (random() - 0.5) is uniform on [-0.5, 0.5).
    return deterministic + (random() - 0.5) * noise_scale

# Create the Ax Service API client and register the experiment: three
# continuous weights on [0, 1] whose sum is capped at 1, minimizing the
# "f_to_minimize" objective.
ax_client = AxClient()
ax_client.create_experiment(
    name="ax_experiment",
    # One identically shaped float range parameter per weight, in the order
    # w1, w2, w3 (each dict gets its own fresh bounds list).
    parameters=[
        {
            "name": weight_name,
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",
        }
        for weight_name in ("w1", "w2", "w3")
    ],
    # Keeps the implied fourth weight, 1 - (w1 + w2 + w3), non-negative.
    parameter_constraints=["w1 + w2 + w3 <= 1.0"],
    # NOTE(review): newer Ax releases replace objective_name/minimize with an
    # `objectives=` argument; confirm against the installed Ax version.
    objective_name="f_to_minimize",
    minimize=True,
)

# Number of sequential ask/evaluate/tell rounds to run against Ax.
N_TRIALS = 25

for _ in range(N_TRIALS):
    # Ask Ax for the next candidate parameterization.
    parameters, trial_index = ax_client.get_next_trial()
    # Evaluate it and report the result back. A bare float is passed, so no
    # SEM is supplied.
    # NOTE(review): per the Ax docs, Ax then infers the observation noise;
    # confirm for the installed version.
    ax_client.complete_trial(
        trial_index=trial_index,
        raw_data=f_to_minimize(parameters["w1"], parameters["w2"], parameters["w3"]),
    )

# get_best_parameters() may return None instead of a pair (Optional return in
# the Ax API). Fail with a clear message rather than an opaque unpacking
# TypeError.
best = ax_client.get_best_parameters()
if best is None:
    raise RuntimeError("Ax could not determine a best parameterization.")
best_parameters, metrics = best
# NOTE(review): per the Ax API docs, the second element holds model-predicted
# means/covariances rather than raw observed metrics; confirm for the
# installed version.
print("Best parameters:", best_parameters)
print("Predicted metrics:", metrics)