From ec181c08399374b8541cf3725d9901d498540661 Mon Sep 17 00:00:00 2001 From: Pat Callaghan Date: Wed, 15 Jun 2022 20:22:43 -0400 Subject: [PATCH] New arguments and argument handling to specify 0,1,2 demo count for DemPref. --- inquire/environments/lunar_lander.py | 4 +- tests/args_handler.py | 447 ++++++++++++++++++++++----- tests/corl22.py | 37 ++- tests/evaluation.py | 173 ++++++++--- 4 files changed, 528 insertions(+), 133 deletions(-) diff --git a/inquire/environments/lunar_lander.py b/inquire/environments/lunar_lander.py index 87af587..416a435 100644 --- a/inquire/environments/lunar_lander.py +++ b/inquire/environments/lunar_lander.py @@ -109,8 +109,8 @@ def generate_random_state(self, random_state): def generate_random_reward(self, random_state): """Randomly generate a weight vector for trajectory features.""" # reward = np.array([-0.4, 0.4, -0.2, -0.7]) - # reward = np.array([0.55, 0.55, 0.41, 0.48]) - reward = np.random.uniform(-1,1,4) + reward = np.array([0.55, 0.55, 0.41, 0.48]) + # reward = np.random.uniform(-1, 1, 4) return reward / np.linalg.norm(reward) def reset(self) -> np.ndarray: diff --git a/tests/args_handler.py b/tests/args_handler.py index 44979ca..5a49ea7 100644 --- a/tests/args_handler.py +++ b/tests/args_handler.py @@ -3,62 +3,212 @@ import os import pdb import time -import numpy as np +import numpy as np from inquire.utils.datatypes import Modality from inquire.utils.sampling import TrajectorySampling -class ArgsHandler(): + +class ArgsHandler: def __init__(self): - parser = argparse.ArgumentParser(description='Parameters for evaluating INQUIRE') - parser.add_argument("-V", "--verbose", dest='verbose', action='store_true', - help='verbose') - parser.add_argument("--use_cache", dest='use_cache', action='store_true', - help='use cached trajectories instead of sampling') - parser.add_argument("--numba", dest='use_numba', action='store_true', - help='use cached trajectories instead of sampling') - parser.add_argument("--static_state", dest='static_state', action='store_true', - help='use the same state for all queries') - parser.add_argument("--teacher_displays", action="store_true", - help="display the teacher's interactions.") - parser.add_argument("-K", "--queries", type=int, dest='num_queries', default=5, - help='number of queries') - parser.add_argument("-R", "--runs", type=int, dest='num_runs', default=10, - help='number of evaluations to run for each task') - parser.add_argument("-X", "--tests", type=int, dest='num_test_states', default=1, - help='number of test states to evaluate') - parser.add_argument("-Z", "--tasks", type=int, dest='num_tasks', default=1, - help='number of task instances to generate') - parser.add_argument("--beta", type=float, dest='beta', default=1.0, - help='optimality parameter') - parser.add_argument("--betas", type=str, dest='beta_vals', - help='optimality parameter, set for each interaction type') - parser.add_argument("--costs", type=str, dest='cost_vals', - help='cost per interaction, set for each interaction type') - parser.add_argument("-C", "--convergence_threshold", type=float, dest='conv_threshold', default=1.0e-1, - help='convergence threshold for estimating weight distribution') - parser.add_argument("--alpha", type=float, dest='step_size', default=0.05, - help='step size for weight optimization') - parser.add_argument("-M", type=int, dest='num_w_samples', default=100, - help='number of weight samples') - parser.add_argument("-N", type=int, dest='num_traj_samples', default=50, - help='number of trajectory samples') - 
parser.add_argument("-D", "--domain", type=str, dest='domain_name', default="linear_combo", choices=["lander", "linear_combo", "linear_system", "pats_linear_system", "pizza"], - help='name of the evaluation domain') - parser.add_argument("-I", "--opt_iterations", type=int, dest='opt_iters', default=50, - help='number of attempts to optimize a sample of controls (pertinent to lunar lander, linear system, and pizza-making domains)') - parser.add_argument("-S", "--sampling", type=str, dest='sampling_method', default="uniform", choices=["uniform"], - help='name of the trajectory sampling method') - parser.add_argument("-A", "--agent", type=str, dest='agent_name', default="inquire", choices=["inquire", "dempref", "biased_dempref", "no-demos", "demo-only", "pref-only", "corr-only", "binary-only", "all", "titrated"], - help='name of the agent to evaluate') - parser.add_argument("-T", "--teacher", type=str, dest='teacher_name', default="optimal", choices=["optimal"], - help='name of the simulated teacher to query') - parser.add_argument("-O", "--output", type=str, dest='output_dir', default="output", - help='name of the output directory') - parser.add_argument("--output_name", type=str, dest='output_name', - help='name of the output filename') - parser.add_argument("--seed_with_n_demos", type=int, dest="n_demos", default=1, - help="how many demos to provide before commencing preference queries. Specific to DemPref.") + parser = argparse.ArgumentParser( + description="Parameters for evaluating INQUIRE" + ) + parser.add_argument( + "--actual_queries", dest="actual_queries", type=int, default=0 + ) + parser.add_argument( + "-V", + "--verbose", + dest="verbose", + action="store_true", + help="verbose", + ) + parser.add_argument( + "--use_cache", + dest="use_cache", + action="store_true", + help="use cached trajectories instead of sampling", + ) + parser.add_argument( + "--numba", + dest="use_numba", + action="store_true", + help="use cached trajectories instead of sampling", + ) + parser.add_argument( + "--static_state", + dest="static_state", + action="store_true", + help="use the same state for all queries", + ) + parser.add_argument( + "--teacher_displays", + action="store_true", + help="display the teacher's interactions.", + ) + parser.add_argument( + "-K", + "--queries", + type=int, + dest="num_queries", + default=5, + help="number of queries", + ) + parser.add_argument( + "-R", + "--runs", + type=int, + dest="num_runs", + default=10, + help="number of evaluations to run for each task", + ) + parser.add_argument( + "-X", + "--tests", + type=int, + dest="num_test_states", + default=1, + help="number of test states to evaluate", + ) + parser.add_argument( + "-Z", + "--tasks", + type=int, + dest="num_tasks", + default=1, + help="number of task instances to generate", + ) + parser.add_argument( + "--beta", + type=float, + dest="beta", + default=1.0, + help="optimality parameter", + ) + parser.add_argument( + "--betas", + type=str, + dest="beta_vals", + help="optimality parameter, set for each interaction type", + ) + parser.add_argument( + "--costs", + type=str, + dest="cost_vals", + help="cost per interaction, set for each interaction type", + ) + parser.add_argument( + "-C", + "--convergence_threshold", + type=float, + dest="conv_threshold", + default=1.0e-1, + help="convergence threshold for estimating weight distribution", + ) + parser.add_argument( + "--alpha", + type=float, + dest="step_size", + default=0.05, + help="step size for weight optimization", + ) + parser.add_argument( + "-M", + 
type=int, + dest="num_w_samples", + default=100, + help="number of weight samples", + ) + parser.add_argument( + "-N", + type=int, + dest="num_traj_samples", + default=50, + help="number of trajectory samples", + ) + parser.add_argument( + "-D", + "--domain", + type=str, + dest="domain_name", + default="linear_combo", + choices=[ + "lander", + "linear_combo", + "linear_system", + "pats_linear_system", + "pizza", + ], + help="name of the evaluation domain", + ) + parser.add_argument( + "-I", + "--opt_iterations", + type=int, + dest="opt_iters", + default=50, + help="number of attempts to optimize a sample of controls (pertinent to lunar lander, linear system, and pizza-making domains)", + ) + parser.add_argument( + "-S", + "--sampling", + type=str, + dest="sampling_method", + default="uniform", + choices=["uniform"], + help="name of the trajectory sampling method", + ) + parser.add_argument( + "-A", + "--agent", + type=str, + dest="agent_name", + default="inquire", + choices=[ + "inquire", + "dempref", + "biased_dempref", + "no-demos", + "demo-only", + "pref-only", + "corr-only", + "binary-only", + "all", + "titrated", + ], + help="name of the agent to evaluate", + ) + parser.add_argument( + "-T", + "--teacher", + type=str, + dest="teacher_name", + default="optimal", + choices=["optimal"], + help="name of the simulated teacher to query", + ) + parser.add_argument( + "-O", + "--output", + type=str, + dest="output_dir", + default="output", + help="name of the output directory", + ) + parser.add_argument( + "--output_name", + type=str, + dest="output_name", + help="name of the output filename", + ) + parser.add_argument( + "--seed_with_n_demos", + type=int, + dest="n_demos", + default=1, + help="how many demos to provide before commencing preference queries. 
Specific to DemPref.", + ) self._args = parser.parse_args() @@ -75,52 +225,60 @@ def __init__(self): self.verbose = self._args.verbose self.output_name = self._args.output_name self.output_dir = self._args.output_dir + self.actual_queries = self._args.actual_queries def setup_domain(self): ## Set up domain if self._args.domain_name == "linear_combo": from inquire.environments.linear_combo import LinearCombination + seed = 42 - w_dim = 8 + w_dim = 8 domain = LinearCombination(seed, w_dim) elif self._args.domain_name == "lander": from inquire.environments.lunar_lander import LunarLander + traj_length = 10 optimization_iteration_count = self._args.opt_iters if self._args.agent_name == "biased_dempref": domain = LunarLander( optimal_trajectory_iterations=optimization_iteration_count, verbose=self._args.verbose, - include_feature_biases=True + include_feature_biases=True, ) else: domain = LunarLander( optimal_trajectory_iterations=optimization_iteration_count, - verbose=self._args.verbose + verbose=self._args.verbose, ) elif self._args.domain_name == "pats_linear_system": - from inquire.environments.pats_linear_dynamical_system import PatsLinearDynamicalSystem + from inquire.environments.pats_linear_dynamical_system import \ + PatsLinearDynamicalSystem + traj_length = 15 optimization_iteration_count = self._args.opt_iters domain = PatsLinearDynamicalSystem( trajectory_length=traj_length, optimal_trajectory_iterations=optimization_iteration_count, - verbose=self._args.verbose + verbose=self._args.verbose, ) elif self._args.domain_name == "linear_system": - from inquire.environments.linear_dynamical_system import LinearDynamicalSystem + from inquire.environments.linear_dynamical_system import \ + LinearDynamicalSystem + traj_length = 15 optimization_iteration_count = self._args.opt_iters domain = LinearDynamicalSystem( trajectory_length=traj_length, optimal_trajectory_iterations=optimization_iteration_count, - verbose=self._args.verbose + verbose=self._args.verbose, ) elif self._args.domain_name == "pizza": from inquire.environments.pizza_making import PizzaMaking + traj_length = how_many_toppings_to_add = 1 max_topping_count = 15 pizza_form = { @@ -155,24 +313,77 @@ def setup_agents(self): ## Set up agent(s) if self._args.agent_name == "titrated": from inquire.agents.inquire import FixedInteractions - ddddd = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION]*5) - ddddp = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION]*4 + [Modality.PREFERENCE]) - dddpp = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION]*3 + [Modality.PREFERENCE]*2) - ddppp = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION]*2 + [Modality.PREFERENCE]*3) - dpppp = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION] + [Modality.PREFERENCE]*4) - ppppp = FixedInteractions(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.PREFERENCE]*5) - agents = [ddddd, ddddp, dddpp, ddppp, dpppp, ppppp] - agent_names = ["DDDDD", "DDDDP", "DDDPP", "DDPPP", "DPPPP", "PPPPP"] - if self._args.agent_name.lower() == "dempref" or self._args.agent_name.lower() == "biased_dempref": + + ddddd = 
FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.DEMONSTRATION] * 5, + ) + ddddp = FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.DEMONSTRATION] * 4 + [Modality.PREFERENCE], + ) + dddpp = FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.DEMONSTRATION] * 3 + [Modality.PREFERENCE] * 2, + ) + ddppp = FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.DEMONSTRATION] * 2 + [Modality.PREFERENCE] * 3, + ) + dpppp = FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.DEMONSTRATION] + [Modality.PREFERENCE] * 4, + ) + ppppp = FixedInteractions( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.PREFERENCE] * 5, + ) + agents = [ddddd, ddddp, dddpp, ddppp, dpppp, ppppp] + agent_names = [ + "DDDDD", + "DDDDP", + "DDDPP", + "DDPPP", + "DPPPP", + "PPPPP", + ] + if ( + self._args.agent_name.lower() == "dempref" + or self._args.agent_name.lower() == "biased_dempref" + ): from inquire.agents.dempref import DemPref - agents = [DemPref( + + agents = [ + DemPref( weight_sample_count=self._args.num_w_samples, trajectory_sample_count=self._args.num_traj_samples, - interaction_types=[Modality.DEMONSTRATION, Modality.PREFERENCE], + interaction_types=[ + Modality.DEMONSTRATION, + Modality.PREFERENCE, + ], w_dim=self.w_dim, seed_with_n_demos=self._args.n_demos, - domain_name=self._args.domain_name - )] + domain_name=self._args.domain_name, + ) + ] agent_names = ["DEMPREF"] if self._args.beta_vals is None: beta = self._args.beta @@ -185,27 +396,104 @@ def setup_agents(self): use_numba = self._args.use_numba if self._args.agent_name == "inquire": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION, Modality.PREFERENCE, Modality.CORRECTION, Modality.BINARY], beta=beta, costs=costs, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [ + Modality.DEMONSTRATION, + Modality.PREFERENCE, + Modality.CORRECTION, + Modality.BINARY, + ], + beta=beta, + costs=costs, + use_numba=use_numba, + ) + ] agent_names = ["INQUIRE"] elif self._args.agent_name == "no-demos": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.PREFERENCE, Modality.CORRECTION, Modality.BINARY], beta=beta, costs=costs, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [ + Modality.PREFERENCE, + Modality.CORRECTION, + Modality.BINARY, + ], + beta=beta, + costs=costs, + use_numba=use_numba, + ) + ] agent_names = ["INQUIRE wo/Demos"] elif self._args.agent_name == "demo-only": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.DEMONSTRATION], beta=beta, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, 
+ [Modality.DEMONSTRATION], + beta=beta, + use_numba=use_numba, + ) + ] agent_names = ["Demo-only"] elif self._args.agent_name == "pref-only": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.PREFERENCE], beta=beta, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.PREFERENCE], + beta=beta, + use_numba=use_numba, + ) + ] agent_names = ["Pref-only"] elif self._args.agent_name == "corr-only": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.CORRECTION], beta=beta, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.CORRECTION], + beta=beta, + use_numba=use_numba, + ) + ] agent_names = ["Corr-only"] elif self._args.agent_name == "binary-only": from inquire.agents.inquire import Inquire - agents = [Inquire(sampling_method, sampling_params, self._args.num_w_samples, self._args.num_traj_samples, [Modality.BINARY], beta=beta, use_numba=use_numba)] + + agents = [ + Inquire( + sampling_method, + sampling_params, + self._args.num_w_samples, + self._args.num_traj_samples, + [Modality.BINARY], + beta=beta, + use_numba=use_numba, + ) + ] agent_names = ["Binary-only"] return agents, agent_names @@ -213,7 +501,8 @@ def setup_teacher(self): ## Set up teacher if self._args.teacher_name == "optimal": from inquire.teachers.optimal import OptimalTeacher + teacher = OptimalTeacher( - self._args.num_traj_samples, self._args.teacher_displays - ) + self._args.num_traj_samples, self._args.teacher_displays + ) return teacher diff --git a/tests/corl22.py b/tests/corl22.py index f0b4ed5..f940843 100644 --- a/tests/corl22.py +++ b/tests/corl22.py @@ -3,10 +3,12 @@ import os import pdb import time + import numpy as np + from args_handler import ArgsHandler -from evaluation import Evaluation from data_utils import save_data, save_plot +from evaluation import Evaluation if __name__ == "__main__": args = ArgsHandler() @@ -23,12 +25,33 @@ eval_start_time = time.strftime("_%m:%d:%H:%M", time.localtime()) for agent, name in zip(agents, agent_names): print("Evaluating " + name + " agent... 
") - perf, dist, q_type = Evaluation.run(domain, teacher, agent, args.num_tasks, args.num_runs, args.num_queries, args.num_test_states, args.step_size, args.conv_threshold, args.use_cache, args.static_state, args.verbose) + perf, dist, q_type, dempref_metric = Evaluation.run( + domain, + teacher, + agent, + args.num_tasks, + args.num_runs, + args.num_queries, + args.num_test_states, + args.step_size, + args.conv_threshold, + args.use_cache, + args.static_state, + args.verbose, + args.actual_queries, + ) if args.output_name is not None: dist_sum = np.sum(dist) perf_sum = np.sum(perf) - with open(args.output_dir + '/' + "overview.txt", "a+") as f: - f.write(args.output_name + ", " + str(dist_sum) + ", " + str(perf_sum) + '\n') + with open(args.output_dir + "/" + "overview.txt", "a+") as f: + f.write( + args.output_name + + ", " + + str(dist_sum) + + ", " + + str(perf_sum) + + "\n" + ) data["distance"].append(dist) data["performance"].append(perf) data["query_types"].append(q_type) @@ -46,17 +69,17 @@ num_runs=args.num_runs, directory=args.output_dir, filename=name + f"_{d}.csv", - subdirectory=domain.__class__.__name__ + subdirectory=domain.__class__.__name__, ) try: save_plot( data["distance"], agent_names, "w distance", - [0,1], + [0, 1], args.output_dir, name + "_distance.png", - subdirectory=domain.__class__.__name__ + subdirectory=domain.__class__.__name__, ) except: print("save_plot() didn't work.") diff --git a/tests/evaluation.py b/tests/evaluation.py index 1045f3f..b3c2a7c 100644 --- a/tests/evaluation.py +++ b/tests/evaluation.py @@ -1,30 +1,49 @@ -import pickle import os import pdb +import pickle import time from pathlib import Path -from inquire.environments.environment import Task, CachedTask -from inquire.utils.datatypes import Modality, CachedSamples import numpy as np -from numpy.random import RandomState import pandas as pd +from inquire.environments.environment import CachedTask, Task +from inquire.utils.datatypes import CachedSamples, Modality +from numpy.random import RandomState + class Evaluation: @staticmethod - def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_states, step_size, convergence_threshold, use_cached_trajectories=False, static_state=False, verbose=False): + def run( + domain, + teacher, + agent, + num_tasks, + num_runs, + num_queries, + num_test_states, + step_size, + convergence_threshold, + use_cached_trajectories=False, + static_state=False, + verbose=False, + actual_queries=None, + ): test_state_rand = RandomState(0) init_w_rand = RandomState(0) - perf_mat = np.zeros((num_tasks,num_runs,num_test_states,num_queries+1)) - dist_mat = np.zeros((num_tasks,num_runs,1,num_queries+1)) - query_mat = np.zeros((num_tasks,num_runs,1,num_queries+1)) - dempref_mat = np.zeros((num_tasks,num_runs,1,num_queries)) - debug = False + real_num_queries = num_queries + if actual_queries != None: + num_queries = actual_queries + perf_mat = np.zeros( + (num_tasks, num_runs, num_test_states, num_queries + 1) + ) + dist_mat = np.zeros((num_tasks, num_runs, 1, num_queries + 1)) + query_mat = np.zeros((num_tasks, num_runs, 1, num_queries + 1)) + debug = True if static_state: query_states = 1 else: - query_states = num_queries + query_states = real_num_queries if verbose: print("Initializing tasks...") if use_cached_trajectories: @@ -32,17 +51,36 @@ def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_state for i in range(num_tasks): state_samples = [] for j in range((num_runs * query_states) + num_test_states): - f_name = "cache/" + 
domain.__class__.__name__ + "_task-" + str(i) \ - + "_state-" + str(j) + "_cache.pkl" + f_name = ( + "cache/" + + domain.__class__.__name__ + + "_task-" + + str(i) + + "_state-" + + str(j) + + "_cache.pkl" + ) if os.path.isfile(f_name): f = open(f_name, "rb") state_samples.append(pickle.load(f)) f.close() else: raise FileNotFoundError(f_name) - tasks.append(CachedTask(state_samples, num_runs * query_states, num_test_states)) + tasks.append( + CachedTask( + state_samples, num_runs * query_states, num_test_states + ) + ) else: - tasks = [Task(domain, num_runs * query_states, num_test_states, test_state_rand) for _ in range(num_tasks)] + tasks = [ + Task( + domain, + num_runs * query_states, + num_test_states, + test_state_rand, + ) + for _ in range(num_tasks) + ] all_w_opt = [] @@ -65,11 +103,21 @@ def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_state if verbose: print("Finding optimal and worst-case baselines...") for p, test_state in enumerate(task.test_states): - best_traj = task.optimal_trajectory_from_ground_truth(test_state) + best_traj = task.optimal_trajectory_from_ground_truth( + test_state + ) max_reward = task.ground_truth_reward(best_traj) - worst_traj = task.least_optimal_trajectory_from_ground_truth(test_state) + worst_traj = task.least_optimal_trajectory_from_ground_truth( + test_state + ) min_reward = task.ground_truth_reward(worst_traj) - test_set.append([test_state, (worst_traj, best_traj), (min_reward, max_reward)]) + test_set.append( + [ + test_state, + (worst_traj, best_traj), + (min_reward, max_reward), + ] + ) if debug: print(f"Finished {p+1}.") @@ -83,7 +131,9 @@ def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_state feedback = [] if agent.__class__.__name__.lower() == "dempref": for _ in range(agent.n_demos): - q = agent.generate_demo_query(task.query_states[state_idx], domain,) + q = agent.generate_demo_query( + task.query_states[state_idx], domain, + ) teacher_fb = teacher.query_response(q, task, verbose) if teacher_fb is not None: feedback.append(teacher_fb) @@ -93,60 +143,93 @@ def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_state w_mean = np.mean(w_dist, axis=0) print(f"w0: {w_mean}") for c in range(num_test_states): - model_traj = domain.optimal_trajectory_from_w(test_set[c][0], w_mean) + model_traj = domain.optimal_trajectory_from_w( + test_set[c][0], w_mean + ) reward = task.ground_truth_reward(model_traj) min_r, max_r = test_set[c][2] perfs.append((reward - min_r) / (max_r - min_r)) if perfs[-1] < -0.1 or perfs[-1] > 1.1: pdb.set_trace() - #assert 0 <= perfs[-1] <= 1 + # assert 0 <= perfs[-1] <= 1 perf_mat[t, r, :, 0] = perfs dist_mat[t, r, 0, 0] = task.distance_from_ground_truth(w_mean) query_mat[t, r, 0, 0] = Modality.NONE.value ## Iterate through queries for k in range(num_queries): q_start = time.perf_counter() - print("\nTask " + str(t+1) + "/" + str(num_tasks) + ", Run " + str(r+1) + "/" + str(num_runs) + ", Query " + str(k+1) + "/" + str(num_queries) + " ", end='\n') + print( + "\nTask " + + str(t + 1) + + "/" + + str(num_tasks) + + ", Run " + + str(r + 1) + + "/" + + str(num_runs) + + ", Query " + + str(k + 1) + + "/" + + str(num_queries) + + " ", + end="\n", + ) ## Generate query and learn from feedback - q = agent.generate_query(domain, task.query_states[state_idx], w_dist, verbose) + q = agent.generate_query( + domain, task.query_states[state_idx], w_dist, verbose + ) state_idx += 1 teacher_fb = teacher.query_response(q, task, verbose) if teacher_fb is not None: 
feedback.append(teacher_fb) - w_dist, w_opt = agent.update_weights(w_dist, domain, feedback, learning_rate=step_size, sample_threshold=convergence_threshold, opt_threshold=1.0e-5) + + w_dist, w_opt = agent.update_weights( + w_dist, + domain, + feedback, + learning_rate=step_size, + sample_threshold=convergence_threshold, + opt_threshold=1.0e-5, + ) print(f"w after query {k+1}: {w_opt.mean(axis=0)}") ## Get performance metrics for each test-state after ## each query and corresponding weight update: perfs = [] for c in range(num_test_states): - model_traj = domain.optimal_trajectory_from_w(test_set[c][0], np.mean(w_opt,axis=0)) + model_traj = domain.optimal_trajectory_from_w( + test_set[c][0], np.mean(w_opt, axis=0) + ) reward = task.ground_truth_reward(model_traj) min_r, max_r = test_set[c][2] if k > 0 and debug: - print(f"Min {min_r}\nMax: {max_r}\nActual: {reward}") + print( + f"Min {min_r}\nMax: {max_r}\nActual: {reward}" + ) perfs.append((reward - min_r) / (max_r - min_r)) # assert 0 <= perf <= 1 - perf_mat[t, r, :, k+1] = perfs - latest_dist = task.distance_from_ground_truth(np.mean(w_opt,axis=0)) - dist_mat[t, r, 0, k+1] = latest_dist + perf_mat[t, r, :, k + 1] = perfs + latest_dist = task.distance_from_ground_truth( + np.mean(w_opt, axis=0) + ) + dist_mat[t, r, 0, k + 1] = latest_dist # dist_mat[t, r, 0, k+1] = task.distance_from_ground_truth(np.mean(w_opt,axis=0)) if k > 0 and debug: print(f"Latest dist: {latest_dist}.") - query_mat[t, r, 0, k+1] = q.query_type.value - dp_met= task.dempref_metric(w_dist) - dempref_mat[t, r, 0, k] = dp_met - all_w_opt.append(w_opt.mean(axis=0).reshape(1,-1)) - if k > 0 and debug: - print(f"Latest dempref metric: {dp_met}.") + all_w_opt.append(w_opt.mean(axis=0).reshape(1, -1)) + query_mat[t, r, 0, k + 1] = q.query_type.value q_time = time.perf_counter() - q_start if verbose: - print(f"Query {k+1} in task {t+1}, run {r+1} took " - f"{q_time:.4}s to complete.") + print( + f"Query {k+1} in task {t+1}, run {r+1} took " + f"{q_time:.4}s to complete." + ) run_time = time.perf_counter() - run_start if verbose: - print(f"Run {r+1} in task {t+1} took {run_time:.4}s " - "to complete.") + print( + f"Run {r+1} in task {t+1} took {run_time:.4}s " + "to complete." + ) task_time = time.perf_counter() - task_start if verbose: @@ -154,10 +237,10 @@ def run(domain, teacher, agent, num_tasks, num_runs, num_queries, num_test_state # learned_w_toppings = domain.make_pizza(np.mean(w_opt, axis=0)) # domain.visualize_pizza(learned_w_toppings) - all_w_opt = np.asarray(all_w_opt).reshape((num_queries*num_runs*num_tasks), domain.w_dim()) - df = pd.DataFrame(all_w_opt) - real_world_path = Path.cwd() / Path("output/RealWorld/") - if not real_world_path.exists(): - real_world_path.mkdir(parents=True) - df.to_csv(str(real_world_path) + "/" + agent.__class__.__name__ + ".csv") + # all_w_opt = np.asarray(all_w_opt).reshape((num_queries*num_runs*num_tasks), domain.w_dim()) + # df = pd.DataFrame(all_w_opt) + # real_world_path = Path.cwd() / Path("output/RealWorld/") + # if not real_world_path.exists(): + # real_world_path.mkdir(parents=True) + # df.to_csv(str(real_world_path) + "/" + agent.__class__.__name__ + ".csv") return perf_mat, dist_mat, query_mat
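
A rough usage sketch of the new options (the flag values are illustrative assumptions, not prescribed by the commit): with corl22.py as the entry point and run from the tests directory, a DemPref evaluation on the lander domain seeded with zero demonstrations could look like

    python corl22.py -D lander -A dempref --seed_with_n_demos 0 -K 5 --actual_queries 5

Here -K still sets how many query states are sampled per run, while the new --actual_queries value appears to cap how many queries are actually issued after the seeded demonstrations, and --seed_with_n_demos selects the 0/1/2 demonstration count named in the subject line.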