interlab.environment.experimental.monitor

 1import numpy as np
 2from matplotlib import pyplot as plt
 3
 4from interlab.environment.base import BaseEnvironment
 5
 6
 7class Monitor:
 8    def __init__(self, env: BaseEnvironment):
 9        self.env = env
10        self.traces: dict[str, tuple[list[any], list[any]]] = {}
11
12    def copy(self, new_env=None):
13        monitor = Monitor(new_env or self.env)
14        for key, value in self.traces.items():
15            monitor.traces[key] = (value[0][:], value[1][:])
16        return monitor
17
18    def get(self, name):
19        if name not in self.traces:
20            self.traces[name] = ([], [])
21        return self.traces[name]
22
23    def trace(self, name, value):
24        x, y = self.get(name)
25        x.append(self.env.steps)
26        y.append(value)
27
28    def line_chart(
29        self,
30        *,
31        names: list[str] | None = None,
32        colors: list[str] | None = None,
33        labels: list[str] | None = None,
34        legend_loc: str = "upper left",
35        cumsum: bool = False,
36    ):
37        fig = plt.figure()
38        plt.title("Payoffs")
39
40        if names is None:
41            names = list(self.traces)
42
43        for i, name in enumerate(names):
44            color = colors[i] if colors else None
45            label = labels[i] if labels else None
46            x, y = self.get(name)
47            if cumsum:
48                y = np.array(y).cumsum()
49            plt.plot(x, y, color=color, label=label)
50        plt.legend(loc=legend_loc)
51        return fig