1
0
mirror of https://github.com/osmarks/rl-testing.git synced 2026-06-13 00:02:11 +00:00

initial commit

This commit is contained in:
2024-07-29 13:40:57 +01:00
commit 94a22e78c4
4 changed files with 120 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
venv
Executable
+2
View File
@@ -0,0 +1,2 @@
#!/bin/sh
rsync -vr {requirements.txt,src/*.py} protagonism:rl-testing
+7
View File
@@ -0,0 +1,7 @@
optax==0.2.3
jax<0.5
gymnasium==0.29
equinox
jaxtyping
haliax
gymnax
+110
View File
@@ -0,0 +1,110 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import numpy as np
import haliax.nn as hnn
import gymnasium as gym
import optax
import matplotlib.pyplot as plt
class MLP(eqx.Module):
input: hnn.Linear
hidden: [hnn.Linear]
output: hnn.Linear
@staticmethod
def init(Input, Hidden, Output, key, n_layers):
i, o, *hs = jax.random.split(key, n_layers + 2)
return MLP(
input=hnn.Linear.init(In=Input, Out=Hidden.alias("hidden0"), key=i),
hidden=[hnn.Linear.init(In=Hidden.alias(f"hidden{i}"), Out=Hidden.alias(f"hidden{i+1}"), key=h) for i, h in enumerate(hs)],
output=hnn.Linear.init(In=Hidden.alias(f"hidden{n_layers}"), Out=Output, key=o)
)
@eqx.filter_jit
def __call__(self, x):
x = self.input(x)
for h in self.hidden:
x = hnn.relu(x)
x = h(x)
return self.output(x)
env = gym.make("LunarLander-v2")
Observation = hax.Axis("observation", env.observation_space.shape[0])
Action = hax.Axis("action", int(env.action_space.n))
Hidden = hax.Axis("hidden", 32)
Batch = hax.Axis("batch", 5000)
def run():
params, sample = jax.random.split(jax.random.PRNGKey(0))
policy_net = MLP.init(Input=Observation, Hidden=Hidden, Output=Action, key=params, n_layers=1)
optimizer = optax.chain(optax.clip(1), optax.adam(1e-2))
opt_state = optimizer.init(policy_net)
def compute_loss(policy_net, obs: hax.NamedArray, act: hax.NamedArray, weights: hax.NamedArray):
logits = policy_net(obs)
logprobs = hax.take(logits, Action, act) - hax.log(hax.sum(hax.exp(logits), Action) + 1e-5)
return -hax.mean(weights * logprobs, axis=Batch).scalar()
@eqx.filter_jit
def do_train_step(model, opt_state, observations, actions, weights):
loss, grad = eqx.filter_value_and_grad(compute_loss)(model, observations, actions, weights)
updates, opt_state = optimizer.update(grad, opt_state)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
mean_return_by_ep = []
for i in range(500):
observations = []
actions = []
weights = []
returns = []
lengths = []
episode_rewards = []
obs, info = env.reset()
done = False
while True:
observations.append(obs)
act = policy_net(hax.named(jnp.array(obs), (Observation,)))
#print(act)
sample, rk = jax.random.split(sample)
act = hax.random.categorical(rk, act, Action).item()
obs, rew, done, truncated, info = env.step(act)
actions.append(act)
episode_rewards.append(rew)
if done:
episode_return = sum(episode_rewards)
returns.append(episode_return)
lengths.append(len(episode_rewards))
weights += [episode_return] * len(episode_rewards)
episode_rewards = []
obs, info = env.reset()
done = False
if len(observations) > Batch.size:
break
observations = hax.named(jnp.array(observations[:Batch.size]), (Batch, Observation))
actions = hax.named(jnp.array(actions[:Batch.size]), (Batch,))
weights = hax.named(jnp.array(weights[:Batch.size]), (Batch,))
weights = (weights - hax.mean(weights, axis=Batch)) / (hax.std(weights, axis=Batch) + 1e-5)
policy_net, opt_state, loss = do_train_step(policy_net, opt_state, observations, actions, weights)
print(i, loss, np.mean(returns), np.mean(lengths))
mean_return_by_ep.append(np.mean(returns))
plt.plot(mean_return_by_ep)
plt.show()
run()