Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add In-Context RL eval #1491

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions evals/elsuite/incontext_rl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# In-Context RL

This eval tests models' ability to solve RL environments simply by interacting with them in-context, without dedicated training or fine-tuning.

## Usage

Run with:

```bash
oaieval <solver> incontext_rl
```

For examples of tested solvers, see [`./scripts/run_experiments.sh`](./scripts/run_experiments.sh).

## Dataset

The eval is currently set up to test models on the following canonical RL environments:
1. [FrozenLake-v1](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) (non-slippery version, default map), 4x4 gridworld where the agent has to reach the goal without falling into traps.
2. [CliffWalking-v0](https://gymnasium.farama.org/environments/toy_text/cliff_walking/). 4x12 gridworld where the agent has to reach the other side of the map without falling off a cliff.
3. [BanditTwoArmedHighLowFixed-v1](https://github.com/james-aung/gymasium-bandits). Stochastic two-armed bandit setup where Arm 1 pays out 80% of the time with reward 1, and Arm 2 pays out 20% of the time with reward 1.
4. [BanditTenArmedRandomFixed-v1](https://github.com/james-aung/gymasium-bandits). Stochastic ten-armed bandit setup where each arm has some randomly-initialized probability of payout.

Besides these four environments, our eval is also built to be compatible with any environments that have discrete action and observation spaces using the Gymnasium API. Future work may generalize our eval to work with environments with other types of action/observation spaces.

## Evaluation Process

Each run of the eval tests the model on all four environments in the dataset, and has the model take steps in each environment until 200 steps are taken or the model’s context limit is reached.

At each step, the eval provides the following to the model:
- The next observation and the reward from the last action. The model is also told when the environment has reset due to its action leading to a termination.
- How many of the maximum number of steps it has already taken.
- The total reward it has accumulated so far across all episodes.

If an episode ends, the environment resets and a new episode begins.

If the eval receive 4 responses in a row where we cannot parse an action selection, we end the evaluation for that environment. (This provides a natural end for runs where the model’s context window is exceeded.)


## Prompts

We refer readers to the [`./defaults.py`](./defaults.py) file for the `TASK_DESCRIPTION` and other prompts used in the eval.

## Metrics
<!-- prettier-ignore-start -->
We provide the following metrics per evaluated environment:

| **Metric** | **Notes** |
|---|---|
| `average_episode_reward` | The average reward achieved per episode |
| `total_steps` | The number of steps taken across all episodes before the environment sample ended |
| `invalid_response_rate` | % of responses that were in an invalid format for the eval |
<!-- prettier-ignore-end -->

## Token Usage Estimates

<!-- prettier-ignore-start -->
| Model | Token Usage Per Run |
|---|---|
| **gpt-3.5-turbo** | 4200000 ± 400000 |
| **gpt-4-turbo-preview** | 21900000 ± 10100000 |
| **mixtral-8x7b** | 2700000 ± 800000 |
<!-- prettier-ignore-end -->

## Future modifications

- Extend the eval to work with other observation and action spaces beyond Discrete spaces

## Version History

- v0: Initial version released

## Contribution Statement

Eval design, implementation, and results evaluation were primarily conducted by James Aung. Chan Jun Shern was responsible for code reviews throughout the implementation process, along with fine-grained feedback on the project in general. Additional guidance was provided by Steven Adler, who scoped and managed the broader research project, including input on evaluation design, results analysis, and interpretation.
38 changes: 38 additions & 0 deletions evals/elsuite/incontext_rl/anti-cot_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any
from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec
from evals.task_state import Message, TaskState

ANTI_COT_TEMPLATE = "RESPOND ONLY WITH YOUR FINAL ANSWER IN THE FORMAT REQUESTED. DO NOT OUTPUT ANY ADDITIONAL REASONING OR TEXT."

class AntiCoTSolver(NestedSolver):
"""
Instructs the model to not do any further reasoning and just respond with the final answer.
"""

def __init__(
self,
solver: SolverSpec,
registry: Any = None,
):
super().__init__(solver=solver)

@property
def solver(self) -> Solver:
return self.get_solver("solver")

def _solve(
self,
task_state: TaskState,
**kwargs,
) -> SolverResult:
task_state.messages += (
[
Message(role="system", content=ANTI_COT_TEMPLATE),
]
)
solver_result = self.solver(task_state=task_state, **kwargs)
return solver_result

@property
def name(self) -> str:
return f"Anti-CoT_{self.solver.name}"
118 changes: 118 additions & 0 deletions evals/elsuite/incontext_rl/baselines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import random

import numpy as np

from evals.elsuite.incontext_rl.eval import CurrentState
from evals.record import record_sampling
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import TaskState


class RandomSolver(Solver):
def __init__(self, *args, **kwargs):
pass

def _solve(
self,
task_state: TaskState,
**kwargs,
) -> SolverResult:

cs: CurrentState = task_state.current_state

try:
action = cs.action_space.sample()
response = f"[SELECT: {action}]"
except Exception as e:
response = f"Error: {e}"

record_sampling(
prompt=cs.observations[-1],
sampled=response,
model="incontext_rl_random",
)

return SolverResult(response)


class QlearningSolver(Solver):
def __init__(
self,
learning_rate=0.7,
gamma=0.95,
epsilon=1.0,
min_epsilon=0.05,
max_epsilon=1.0,
decay_rate=0.0005,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.learning_rate = learning_rate
self.gamma = gamma
self.epsilon = epsilon
self.min_epsilon = min_epsilon
self.max_epsilon = max_epsilon
self.decay_rate = decay_rate
self.q_table = None

def initialize_q_table(self, observation_space_size, action_space_size):
self.q_table = np.zeros((observation_space_size, action_space_size))

def select_action(self, state, action_space):
if random.uniform(0, 1) < self.epsilon:
return action_space.sample() # Explore action space
else:
return np.argmax(self.q_table[state][:]) # Exploit learned values

def update_q_table(self, state, action, reward, next_state):
next_max = np.max(self.q_table[next_state])
self.q_table[state, action] = self.q_table[state, action] + self.learning_rate * (
reward + self.gamma * next_max - self.q_table[state, action]
)

def reduce_epsilon(self, episode_number):
self.epsilon = self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(
-self.decay_rate * episode_number
)

def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:

cs: CurrentState = task_state.current_state

# TODO these might not be true if environment is not discrete
assert (
cs.observation_space_n is not None
), "Environment must have discrete observation space"
assert cs.action_space_n is not None, "Environment must have discrete action space"

if self.q_table is None:
print("Initializing Q-table")
self.initialize_q_table(
observation_space_size=cs.observation_space_n, action_space_size=cs.action_space_n
)

# This shouln't run on the first step
if len(cs.actions) >= 1 and len(cs.rewards) >= 1 and len(cs.observations) >= 2:
print(cs.actions)
self.update_q_table(
state=cs.observations[-2],
action=cs.actions[-1],
reward=cs.rewards[-1],
next_state=cs.observations[-1],
)
print(
f"The last action {cs.actions[-1]} resulted in reward {cs.rewards[-1]}. We went from state {cs.observations[-2]} to state {cs.observations[-1]}"
)
self.reduce_epsilon(episode_number=len(cs.episode_end_steps))

action = self.select_action(state=cs.observations[-1], action_space=cs.action_space)
response = f"[SELECT: {action}]"

record_sampling(
prompt=cs.observations[-1],
sampled=response,
model="incontext_rl_qlearning",
)

return SolverResult(response)
30 changes: 30 additions & 0 deletions evals/elsuite/incontext_rl/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from string import Template

task_description_template = Template("""$explanations

You have the following $action_space actions available to you: $actions.

On each turn, you will receive an observation from the environment, which maybe be one of the following $observation_space possible observations:
$observations
Each action you take will give you a reward (which might be negative) and may move you to a new state.

You should take actions such that you accumulate the highest possible reward across all your actions.

This is how we will play the game:
1. You will consider all the information you have received to decide upon the best action you can take now.
2. You will select your action from the list above by specifying the number key of the action in the command [SELECT: x], where x is the number key of the action.
3. Your selected action will be taken.
4. As a result of your action, you will be given an observation from the environment and you may receive some reward.
5. Repeat from step 1.
""")

step_counter = Template("Total actions taken so far: $step_count")
reward_counter = Template("Total reward so far: $reward_count")
reset_msg = Template("""After the game reset you are now in $observation.
Please pick an action, providing your reasoning. You must format your final action choice as [SELECT: x]""")
step_result = Template("""You took Action $action. You are now in $next_observation.
The last step you did provided reward: $reward.
Please pick an action, providing your reasoning. You must format your final action choice as [SELECT: x]""")
step_result_reset = Template("""You took Action $action. You arrived at $next_observation.
The last step made the game reset.
The last step you did provided reward: $reward.""")
12 changes: 12 additions & 0 deletions evals/elsuite/incontext_rl/env_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Optional setup scripts for specific environments.
"""

def setup_GymnasiumBandits():
import gymnasium_bandits
return

ENV_SETUP_FUNCS = {
"BanditTwoArmedHighLowFixed-v0": setup_GymnasiumBandits,
"BanditTenArmedRandomFixed-v0": setup_GymnasiumBandits,
}
Loading
Loading