# Temperature decay scheduler for training: the temperature starts at a high
# initial value and decays toward a final value over the training steps
# according to a chosen schedule (linear, exponential, or cosine).
import math
def temperature_decay(
    initial_temp: float,
    final_temp: float,
    current_step: int,
    total_steps: int,
    schedule: str = "cosine",
) -> float:
    """Return the annealed temperature for ``current_step`` of ``total_steps``.

    The temperature moves from ``initial_temp`` (step 0) to ``final_temp``
    (step ``total_steps``) along the chosen schedule, with
    ``progress = current_step / total_steps``:

    * ``"linear"``:      T = T_init + (T_final - T_init) * progress
    * ``"exponential"``: T = T_init * (T_final / T_init) ** progress
    * ``"cosine"``:      T = T_final + 0.5 * (T_init - T_final) * (1 + cos(pi * progress))

    Args:
        initial_temp: Temperature at the start of the schedule.
        final_temp: Temperature at the end of the schedule.
        current_step: Current training step.
        total_steps: Length of the schedule; if ``<= 0`` the schedule is
            treated as already finished.
        schedule: One of ``"linear"``, ``"exponential"`` or ``"cosine"``.

    Returns:
        ``final_temp`` when ``total_steps <= 0`` or
        ``current_step >= total_steps``; otherwise ``initial_temp`` when
        ``current_step <= 0``; otherwise the interpolated temperature rounded
        to 6 decimal places.

    Raises:
        ValueError: If ``schedule`` is not a supported name, or if
            ``schedule == "exponential"`` and either temperature is not
            strictly positive (the ratio would divide by zero, collapse to
            zero, or become a complex number).
    """
    if schedule not in ("linear", "exponential", "cosine"):
        raise ValueError(
            f"unknown schedule {schedule!r}; "
            "expected 'linear', 'exponential' or 'cosine'"
        )
    # Validate before the step-boundary shortcuts so a bad configuration
    # fails immediately at step 0 instead of partway through training.
    if schedule == "exponential" and (initial_temp <= 0 or final_temp <= 0):
        raise ValueError(
            "exponential schedule requires initial_temp > 0 and final_temp > 0, "
            f"got initial_temp={initial_temp!r}, final_temp={final_temp!r}"
        )

    if total_steps <= 0 or current_step >= total_steps:
        return final_temp
    if current_step <= 0:
        return initial_temp

    progress = current_step / total_steps
    if schedule == "linear":
        temp = initial_temp + (final_temp - initial_temp) * progress
    elif schedule == "exponential":
        # Geometric interpolation: the temperature shrinks by a constant
        # ratio per unit of progress.
        temp = initial_temp * (final_temp / initial_temp) ** progress
    else:  # "cosine"
        # Half-cosine: changes slowly near both ends, fastest in the middle.
        temp = final_temp + 0.5 * (initial_temp - final_temp) * (
            1 + math.cos(math.pi * progress)
        )
    return round(temp, 6)