from typing import Any

import numpy as np
from matplotlib import pyplot as plt

from interlab.environment.base import BaseEnvironment
5
6
class Monitor:
    """Record named series of values, each stamped with the environment's step count.

    Every trace is stored as a pair of parallel lists ``(steps, values)``;
    :meth:`trace` appends ``env.steps`` and the value together.
    """

    def __init__(self, env: BaseEnvironment):
        """Create an empty monitor bound to *env*.

        Args:
            env: Environment whose ``steps`` attribute is read to timestamp
                every traced value.
        """
        self.env = env
        # name -> (steps, values); trace() appends to both lists in lockstep.
        self.traces: dict[str, tuple[list[Any], list[Any]]] = {}

    def copy(self, new_env=None):
        """Return an independent copy of this monitor.

        Args:
            new_env: Environment for the copy; ``None`` keeps the current one.

        Returns:
            A new :class:`Monitor` whose trace lists are shallow copies, so
            appending to either monitor does not affect the other.
        """
        # `is None` instead of `or`: an environment that happens to be falsy
        # (e.g. defines __bool__ or __len__) must still be honoured.
        env = self.env if new_env is None else new_env
        monitor = Monitor(env)
        monitor.traces = {
            name: (list(steps), list(values))
            for name, (steps, values) in self.traces.items()
        }
        return monitor

    def get(self, name):
        """Return the ``(steps, values)`` lists for *name*, creating an empty trace if absent.

        The returned lists are the monitor's live storage, not copies.
        """
        return self.traces.setdefault(name, ([], []))

    def trace(self, name, value):
        """Append *value* to trace *name*, stamped with the current ``env.steps``."""
        steps, values = self.get(name)
        steps.append(self.env.steps)
        values.append(value)

    def line_chart(
        self,
        *,
        names: list[str] | None = None,
        colors: list[str] | None = None,
        labels: list[str] | None = None,
        legend_loc: str = "upper left",
        cumsum: bool = False,
        title: str = "Payoffs",
    ):
        """Plot the selected traces as lines on a newly created figure.

        Args:
            names: Traces to plot, in order; defaults to every recorded trace.
                Unknown names plot as empty lines and are not added to the
                monitor.
            colors: Per-trace line colours aligned with *names*; when omitted
                matplotlib picks the colours.
            labels: Per-trace legend labels aligned with *names*; when omitted
                the trace names are used, so the legend is never empty.
            legend_loc: Legend position passed to ``Axes.legend``.
            cumsum: Plot the running sum of each trace instead of raw values.
            title: Chart title.

        Returns:
            The new matplotlib ``Figure``.
        """
        fig = plt.figure()
        ax = fig.add_subplot()
        ax.set_title(title)

        if names is None:
            names = list(self.traces)
        if not labels:
            # Unlabeled lines would leave the legend empty (and make matplotlib warn).
            labels = names

        for i, name in enumerate(names):
            color = colors[i] if colors else None
            # Plain lookup, not self.get(): plotting must not create traces.
            steps, values = self.traces.get(name, ([], []))
            if cumsum:
                values = np.cumsum(values)
            ax.plot(steps, values, color=color, label=labels[i])
        if names:
            ax.legend(loc=legend_loc)
        return fig