"""
Based on
1. https://ax.dev/tutorials/tune_cnn.html
2. https://ax.dev/docs/api.html
3. https://ax.dev/tutorials/gpei_hartmann_service.html
4. https://ax.dev/tutorials/gpei_hartmann_loop.html
"""
from random import random
from ax.service.ax_client import AxClient
def f_to_minimize(w1: float, w2: float, w3: float, *, noise_scale: float = 0.1) -> float:
    """Return a noisy four-term quadratic objective of the weights.

    A fourth weight ``w4`` is derived as ``1 - (w1 + w2 + w3)`` so that the
    four weights sum to one.

    Args:
        w1: First weight.
        w2: Second weight.
        w3: Third weight.
        noise_scale: Width of the uniform noise added to the result; the
            noise is drawn from ``[-noise_scale / 2, noise_scale / 2)``.
            Pass ``0.0`` for a deterministic evaluation. The default
            reproduces the original hard-coded ``0.1``.

    Returns:
        The objective value plus the random noise term.
    """
    w4 = 1 - (w1 + w2 + w3)
    term_1 = (0.5 - w1) ** 2
    term_2 = 1.2 * (w2 - w1 ** 2) ** 2
    term_3 = (1.5 - w3) ** 2
    term_4 = 3.2 * (w4 - w3 ** 2) ** 2
    # random() is in [0, 1), so centring on 0.5 makes the noise zero-mean.
    noise = (random() - 0.5) * noise_scale
    return term_1 + term_2 + term_3 + term_4 + noise
# Set up the Ax experiment: three float weights, each ranging over [0, 1],
# constrained so their sum does not exceed 1, with a single objective that
# is to be minimized.
ax_client = AxClient()
ax_client.create_experiment(
    name="ax_experiment",
    parameters=[
        {
            "name": weight_name,
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",
        }
        for weight_name in ("w1", "w2", "w3")
    ],
    parameter_constraints=["w1 + w2 + w3 <= 1.0"],
    objective_name="f_to_minimize",
    minimize=True,
)
# Run 25 sequential trials: request a candidate from the client, evaluate
# the objective on it, and report the observed value back for that trial.
for _ in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=trial_index,
        raw_data=f_to_minimize(
            *(parameters[key] for key in ("w1", "w2", "w3"))
        ),
    )

# Query the client for the best parameterization found and its metrics.
best_parameters, metrics = ax_client.get_best_parameters()