-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
37 lines (28 loc) · 1.03 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import jax.numpy as jnp
from qdax.types import Metrics
from typing import Protocol, TYPE_CHECKING
if TYPE_CHECKING:
from .containers import ExtendedRepertoire
class MetricsFn(Protocol):
    """Structural type for callables that turn a repertoire into metrics.

    Any callable that accepts an ``ExtendedRepertoire`` positionally (as
    ``repertoire``) and returns a ``Metrics`` value satisfies this protocol.
    For example, ``functools.partial(qd_metrics, qd_offset=...)`` fits it.
    """

    def __call__(self, repertoire: 'ExtendedRepertoire') -> Metrics:
        """Return the metrics computed for ``repertoire``."""
def qd_metrics(
    repertoire: 'ExtendedRepertoire',
    qd_offset: float,
) -> 'Metrics':
    """Compute standard quality-diversity metrics for a repertoire.

    A cell is treated as empty when its fitness equals ``-jnp.inf``.

    Args:
        repertoire: Repertoire exposing a ``fitnesses`` array; only that
            attribute is read here.
        qd_offset: Constant added to the QD score once per filled cell,
            i.e. the score gains ``qd_offset * number_of_filled_cells``.

    Returns:
        Dict with keys ``'qd_score'``, ``'max_fitness'``, ``'coverage'``
        (percentage of filled cells, 0-100), ``'min_fitness'`` and
        ``'mean_fitness'``. For an empty repertoire, ``qd_score`` and
        ``coverage`` are 0, and max/min/mean fitness are all ``-inf``.
    """
    fitnesses = repertoire.fitnesses
    repertoire_empty = fitnesses == -jnp.inf
    filled = ~repertoire_empty

    # Sum only over filled cells; the -inf placeholders would poison the total.
    fitness_sum = jnp.sum(fitnesses, where=filled)
    qd_score = fitness_sum + qd_offset * jnp.sum(1.0 - repertoire_empty)
    coverage = 100 * jnp.mean(1.0 - repertoire_empty)

    # Empty cells hold -inf, so they only win the max when every cell is empty.
    max_fitness = jnp.max(fitnesses)
    # A masked min needs an explicit `initial`. Seeding it with max_fitness
    # yields the smallest filled fitness, or -inf when nothing is filled.
    min_fitness = jnp.min(fitnesses, initial=max_fitness, where=filled)

    # jnp.mean(where=...) would compute 0/0 = NaN on an empty repertoire.
    # Keep the divisor >= 1 (so no NaN is ever produced, even transiently)
    # and report -inf instead, matching max/min above.
    num_filled = jnp.sum(filled)
    mean_fitness = jnp.where(
        num_filled > 0,
        fitness_sum / jnp.maximum(num_filled, 1),
        -jnp.inf,
    )

    metrics = {
        'qd_score': qd_score,
        'max_fitness': max_fitness,
        'coverage': coverage,
        'min_fitness': min_fitness,
        'mean_fitness': mean_fitness,
    }
    return metrics