mirror of https://github.com/osmarks/rl-testing.git (synced 2024-11-10 22:29:54 +00:00)

initial commit
commit 94a22e78c4
.gitignore (vendored, new file)
@@ -0,0 +1 @@
venv
deploy.sh (executable, new file)
@@ -0,0 +1,2 @@
#!/bin/sh
rsync -vr {requirements.txt,src/*.py} protagonism:rl-testing
requirements.txt (new file)
@@ -0,0 +1,7 @@
optax==0.2.3
jax<0.5
gymnasium==0.29
equinox
jaxtyping
haliax
gymnax
src/run.py (new file)
@@ -0,0 +1,110 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import haliax as hax
import numpy as np
import haliax.nn as hnn
import gymnasium as gym
import optax
import matplotlib.pyplot as plt


# Simple MLP policy built from haliax named-axis linear layers.
class MLP(eqx.Module):
    input: hnn.Linear
    hidden: list[hnn.Linear]
    output: hnn.Linear

    @staticmethod
    def init(Input, Hidden, Output, key, n_layers):
        i, o, *hs = jax.random.split(key, n_layers + 2)
        return MLP(
            input=hnn.Linear.init(In=Input, Out=Hidden.alias("hidden0"), key=i),
            hidden=[hnn.Linear.init(In=Hidden.alias(f"hidden{i}"), Out=Hidden.alias(f"hidden{i+1}"), key=h) for i, h in enumerate(hs)],
            output=hnn.Linear.init(In=Hidden.alias(f"hidden{n_layers}"), Out=Output, key=o)
        )

    @eqx.filter_jit
    def __call__(self, x):
        x = self.input(x)
        for h in self.hidden:
            x = hnn.relu(x)
            x = h(x)
        return self.output(x)


env = gym.make("LunarLander-v2")

Observation = hax.Axis("observation", env.observation_space.shape[0])
Action = hax.Axis("action", int(env.action_space.n))
Hidden = hax.Axis("hidden", 32)
Batch = hax.Axis("batch", 5000)


def run():
    params, sample = jax.random.split(jax.random.PRNGKey(0))
    policy_net = MLP.init(Input=Observation, Hidden=Hidden, Output=Action, key=params, n_layers=1)

    optimizer = optax.chain(optax.clip(1), optax.adam(1e-2))
    opt_state = optimizer.init(policy_net)

    # Vanilla policy-gradient surrogate: -(weight * log pi(a|s)), averaged over the batch.
    def compute_loss(policy_net, obs: hax.NamedArray, act: hax.NamedArray, weights: hax.NamedArray):
        logits = policy_net(obs)
        logprobs = hax.take(logits, Action, act) - hax.log(hax.sum(hax.exp(logits), Action) + 1e-5)
        return -hax.mean(weights * logprobs, axis=Batch).scalar()

    @eqx.filter_jit
    def do_train_step(model, opt_state, observations, actions, weights):
        loss, grad = eqx.filter_value_and_grad(compute_loss)(model, observations, actions, weights)
        updates, opt_state = optimizer.update(grad, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    mean_return_by_ep = []

    for i in range(500):
        observations = []
        actions = []
        weights = []
        returns = []
        lengths = []

        episode_rewards = []

        obs, info = env.reset()

        done = False

        # Collect a batch of on-policy experience.
        while True:
            observations.append(obs)

            act = policy_net(hax.named(jnp.array(obs), (Observation,)))
            #print(act)
            sample, rk = jax.random.split(sample)
            act = hax.random.categorical(rk, act, Action).item()
            obs, rew, done, truncated, info = env.step(act)

            actions.append(act)
            episode_rewards.append(rew)

            # Treat time-limit truncation as an episode boundary too, so the env is reset.
            if done or truncated:
                episode_return = sum(episode_rewards)
                returns.append(episode_return)
                lengths.append(len(episode_rewards))
                # Weight every step of the episode by its total return (REINFORCE).
                weights += [episode_return] * len(episode_rewards)
                episode_rewards = []
                obs, info = env.reset()
                done = False

                if len(observations) > Batch.size:
                    break

        observations = hax.named(jnp.array(observations[:Batch.size]), (Batch, Observation))
        actions = hax.named(jnp.array(actions[:Batch.size]), (Batch,))
        weights = hax.named(jnp.array(weights[:Batch.size]), (Batch,))
        # Normalise the return weights to reduce gradient variance.
        weights = (weights - hax.mean(weights, axis=Batch)) / (hax.std(weights, axis=Batch) + 1e-5)
        policy_net, opt_state, loss = do_train_step(policy_net, opt_state, observations, actions, weights)
        print(i, loss, np.mean(returns), np.mean(lengths))

        mean_return_by_ep.append(np.mean(returns))

    plt.plot(mean_return_by_ep)
    plt.show()

run()