mirror of https://github.com/osmarks/rl-testing.git synced 2024-11-10 22:29:54 +00:00

initial commit

osmarks 2024-07-29 13:40:57 +01:00
commit 94a22e78c4
4 changed files with 120 additions and 0 deletions

.gitignore Normal file

@@ -0,0 +1 @@
venv

deploy.sh Executable file

@@ -0,0 +1,2 @@
#!/bin/sh
rsync -vr requirements.txt src/*.py protagonism:rl-testing

requirements.txt Normal file

@@ -0,0 +1,8 @@
optax==0.2.3
jax<0.5
gymnasium[box2d]==0.29
equinox
jaxtyping
haliax
gymnax
matplotlib

src/run.py Normal file

@@ -0,0 +1,110 @@
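# Minimal on-policy policy gradient (REINFORCE-style, with whole-episode return
# weights) on Gymnasium's LunarLander-v2, built on jax/equinox/haliax/optax.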
import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import numpy as np
import haliax.nn as hnn
import gymnasium as gym
import optax
import matplotlib.pyplot as plt

# Simple multi-layer perceptron over haliax named axes.
class MLP(eqx.Module):
    input: hnn.Linear
    hidden: list[hnn.Linear]
    output: hnn.Linear

    @staticmethod
    def init(Input, Hidden, Output, key, n_layers):
        i, o, *hs = jax.random.split(key, n_layers + 2)
        return MLP(
            input=hnn.Linear.init(In=Input, Out=Hidden.alias("hidden0"), key=i),
            hidden=[hnn.Linear.init(In=Hidden.alias(f"hidden{i}"), Out=Hidden.alias(f"hidden{i+1}"), key=h) for i, h in enumerate(hs)],
            output=hnn.Linear.init(In=Hidden.alias(f"hidden{n_layers}"), Out=Output, key=o)
        )

    @eqx.filter_jit
    def __call__(self, x):
        x = self.input(x)
        for h in self.hidden:
            x = hnn.relu(x)
            x = h(x)
        return self.output(x)

env = gym.make("LunarLander-v2")
Observation = hax.Axis("observation", env.observation_space.shape[0])
Action = hax.Axis("action", int(env.action_space.n))
Hidden = hax.Axis("hidden", 32)
Batch = hax.Axis("batch", 5000)

def run():
    params, sample = jax.random.split(jax.random.PRNGKey(0))
    policy_net = MLP.init(Input=Observation, Hidden=Hidden, Output=Action, key=params, n_layers=1)
    optimizer = optax.chain(optax.clip(1), optax.adam(1e-2))
    opt_state = optimizer.init(policy_net)

    # Policy-gradient loss: negative log-probability of each taken action,
    # weighted by the (normalized) return of the episode it came from.
    def compute_loss(policy_net, obs: hax.NamedArray, act: hax.NamedArray, weights: hax.NamedArray):
        logits = policy_net(obs)
        logprobs = hax.take(logits, Action, act) - hax.log(hax.sum(hax.exp(logits), Action) + 1e-5)
        return -hax.mean(weights * logprobs, axis=Batch).scalar()

    @eqx.filter_jit
    def do_train_step(model, opt_state, observations, actions, weights):
        loss, grad = eqx.filter_value_and_grad(compute_loss)(model, observations, actions, weights)
        updates, opt_state = optimizer.update(grad, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    mean_return_by_ep = []
    for i in range(500):
        # Collect one batch of on-policy experience.
        observations = []
        actions = []
        weights = []
        returns = []
        lengths = []
        episode_rewards = []
        obs, info = env.reset()
        done = False
        while True:
            observations.append(obs)
            act = policy_net(hax.named(jnp.array(obs), (Observation,)))
            #print(act)
            sample, rk = jax.random.split(sample)
            act = hax.random.categorical(rk, act, Action).item()
            obs, rew, done, truncated, info = env.step(act)
            actions.append(act)
            episode_rewards.append(rew)
            if done or truncated:  # end the episode on termination or time-limit truncation
                # Weight every step of the finished episode by its total return.
                episode_return = sum(episode_rewards)
                returns.append(episode_return)
                lengths.append(len(episode_rewards))
                weights += [episode_return] * len(episode_rewards)
                episode_rewards = []
                obs, info = env.reset()
                done = False
                if len(observations) > Batch.size:
                    break
        observations = hax.named(jnp.array(observations[:Batch.size]), (Batch, Observation))
        actions = hax.named(jnp.array(actions[:Batch.size]), (Batch,))
        weights = hax.named(jnp.array(weights[:Batch.size]), (Batch,))
        # Normalize the return weights to reduce gradient variance.
        weights = (weights - hax.mean(weights, axis=Batch)) / (hax.std(weights, axis=Batch) + 1e-5)
        policy_net, opt_state, loss = do_train_step(policy_net, opt_state, observations, actions, weights)
        print(i, loss, np.mean(returns), np.mean(lengths))
        mean_return_by_ep.append(np.mean(returns))

    plt.plot(mean_return_by_ep)
    plt.show()

run()