Mirror of https://github.com/osmarks/rl-testing.git, synced 2024-11-13 07:39:53 +00:00

Commit 94a22e78c4: initial commit

.gitignore (vendored, new file, 1 line)
@@ -0,0 +1 @@
venv

deploy.sh (new executable file, 2 lines)
@@ -0,0 +1,2 @@
#!/bin/sh
rsync -vr {requirements.txt,src/*.py} protagonism:rl-testing

requirements.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
optax==0.2.3
jax<0.5
gymnasium==0.29
equinox
jaxtyping
haliax
gymnax

src/run.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import numpy as np
import haliax.nn as hnn
import gymnasium as gym
import optax
import matplotlib.pyplot as plt

class MLP(eqx.Module):
    input: hnn.Linear
    hidden: [hnn.Linear]
    output: hnn.Linear

    @staticmethod
    def init(Input, Hidden, Output, key, n_layers):
        i, o, *hs = jax.random.split(key, n_layers + 2)
        return MLP(
            input=hnn.Linear.init(In=Input, Out=Hidden.alias("hidden0"), key=i),
            hidden=[hnn.Linear.init(In=Hidden.alias(f"hidden{i}"), Out=Hidden.alias(f"hidden{i+1}"), key=h) for i, h in enumerate(hs)],
            output=hnn.Linear.init(In=Hidden.alias(f"hidden{n_layers}"), Out=Output, key=o)
        )

    @eqx.filter_jit
    def __call__(self, x):
        x = self.input(x)
        for h in self.hidden:
            x = hnn.relu(x)
            x = h(x)
        return self.output(x)

env = gym.make("LunarLander-v2")

Observation = hax.Axis("observation", env.observation_space.shape[0])
Action = hax.Axis("action", int(env.action_space.n))
Hidden = hax.Axis("hidden", 32)
Batch = hax.Axis("batch", 5000)

def run():
    params, sample = jax.random.split(jax.random.PRNGKey(0))
    policy_net = MLP.init(Input=Observation, Hidden=Hidden, Output=Action, key=params, n_layers=1)

    optimizer = optax.chain(optax.clip(1), optax.adam(1e-2))
    opt_state = optimizer.init(policy_net)

    def compute_loss(policy_net, obs: hax.NamedArray, act: hax.NamedArray, weights: hax.NamedArray):
        logits = policy_net(obs)
        logprobs = hax.take(logits, Action, act) - hax.log(hax.sum(hax.exp(logits), Action) + 1e-5)
        return -hax.mean(weights * logprobs, axis=Batch).scalar()

    @eqx.filter_jit
    def do_train_step(model, opt_state, observations, actions, weights):
        loss, grad = eqx.filter_value_and_grad(compute_loss)(model, observations, actions, weights)
        updates, opt_state = optimizer.update(grad, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    mean_return_by_ep = []

    for i in range(500):
        observations = []
        actions = []
        weights = []
        returns = []
        lengths = []

        episode_rewards = []

        obs, info = env.reset()

        done = False

        while True:
            observations.append(obs)

            act = policy_net(hax.named(jnp.array(obs), (Observation,)))
            #print(act)
            sample, rk = jax.random.split(sample)
            act = hax.random.categorical(rk, act, Action).item()
            obs, rew, done, truncated, info = env.step(act)

            actions.append(act)
            episode_rewards.append(rew)

            if done:
                episode_return = sum(episode_rewards)
                returns.append(episode_return)
                lengths.append(len(episode_rewards))
                weights += [episode_return] * len(episode_rewards)
                episode_rewards = []
                obs, info = env.reset()
                done = False

            if len(observations) > Batch.size:
                break

        observations = hax.named(jnp.array(observations[:Batch.size]), (Batch, Observation))
        actions = hax.named(jnp.array(actions[:Batch.size]), (Batch,))
        weights = hax.named(jnp.array(weights[:Batch.size]), (Batch,))
        weights = (weights - hax.mean(weights, axis=Batch)) / (hax.std(weights, axis=Batch) + 1e-5)
        policy_net, opt_state, loss = do_train_step(policy_net, opt_state, observations, actions, weights)
        print(i, loss, np.mean(returns), np.mean(lengths))

        mean_return_by_ep.append(np.mean(returns))

    plt.plot(mean_return_by_ep)
    plt.show()

run()
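
Note: compute_loss in src/run.py is the simplest score-function policy gradient (REINFORCE): the log-probability of each action taken is weighted by the normalised return of its episode and averaged over the batch. Below is a minimal plain-JAX sketch of the same loss without the Haliax named-axis machinery; the names pg_loss and apply_policy are illustrative only and not part of the repository.

import jax
import jax.numpy as jnp

# Score-function (REINFORCE) loss: maximise E[weight * log pi(a|s)],
# i.e. minimise its negation, matching compute_loss above.
def pg_loss(params, apply_policy, obs, acts, weights):
    # obs: (batch, obs_dim) floats; acts: (batch,) ints; weights: (batch,) normalised episode returns
    logits = apply_policy(params, obs)              # (batch, n_actions)
    logprobs = jax.nn.log_softmax(logits, axis=-1)  # log pi(a|s) for every action
    taken = jnp.take_along_axis(logprobs, acts[:, None], axis=-1)[:, 0]  # log pi of the action taken
    return -jnp.mean(weights * taken)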