From 94a22e78c4d57eb82bfcf60678184cc0204fa758 Mon Sep 17 00:00:00 2001
From: osmarks
Date: Mon, 29 Jul 2024 13:40:57 +0100
Subject: [PATCH] initial commit

---
 .gitignore       |   1 +
 deploy.sh        |   2 +
 requirements.txt |   7 +++
 src/run.py       | 110 +++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 120 insertions(+)
 create mode 100644 .gitignore
 create mode 100755 deploy.sh
 create mode 100644 requirements.txt
 create mode 100644 src/run.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f5e96db
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+venv
\ No newline at end of file
diff --git a/deploy.sh b/deploy.sh
new file mode 100755
index 0000000..29cff88
--- /dev/null
+++ b/deploy.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+rsync -vr {requirements.txt,src/*.py} protagonism:rl-testing
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..815b11b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+optax==0.2.3
+jax<0.5
+gymnasium==0.29
+equinox
+jaxtyping
+haliax
+gymnax
\ No newline at end of file
diff --git a/src/run.py b/src/run.py
new file mode 100644
index 0000000..59af4ef
--- /dev/null
+++ b/src/run.py
@@ -0,0 +1,110 @@
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import haliax as hax
+import numpy as np
+import haliax.nn as hnn
+import gymnasium as gym
+import optax
+import matplotlib.pyplot as plt
+
+class MLP(eqx.Module):
+    input: hnn.Linear
+    hidden: [hnn.Linear]
+    output: hnn.Linear
+
+    @staticmethod
+    def init(Input, Hidden, Output, key, n_layers):
+        i, o, *hs = jax.random.split(key, n_layers + 2)
+        return MLP(
+            input=hnn.Linear.init(In=Input, Out=Hidden.alias("hidden0"), key=i),
+            hidden=[hnn.Linear.init(In=Hidden.alias(f"hidden{i}"), Out=Hidden.alias(f"hidden{i+1}"), key=h) for i, h in enumerate(hs)],
+            output=hnn.Linear.init(In=Hidden.alias(f"hidden{n_layers}"), Out=Output, key=o)
+        )
+
+    @eqx.filter_jit
+    def __call__(self, x):
+        x = self.input(x)
+        for h in self.hidden:
+            x = hnn.relu(x)
+            x = h(x)
+        return self.output(x)
+
+env = gym.make("LunarLander-v2")
+
+Observation = hax.Axis("observation", env.observation_space.shape[0])
+Action = hax.Axis("action", int(env.action_space.n))
+Hidden = hax.Axis("hidden", 32)
+Batch = hax.Axis("batch", 5000)
+
+def run():
+    params, sample = jax.random.split(jax.random.PRNGKey(0))
+    policy_net = MLP.init(Input=Observation, Hidden=Hidden, Output=Action, key=params, n_layers=1)
+
+    optimizer = optax.chain(optax.clip(1), optax.adam(1e-2))
+    opt_state = optimizer.init(policy_net)
+
+    def compute_loss(policy_net, obs: hax.NamedArray, act: hax.NamedArray, weights: hax.NamedArray):
+        logits = policy_net(obs)
+        logprobs = hax.take(logits, Action, act) - hax.log(hax.sum(hax.exp(logits), Action) + 1e-5)
+        return -hax.mean(weights * logprobs, axis=Batch).scalar()
+
+    @eqx.filter_jit
+    def do_train_step(model, opt_state, observations, actions, weights):
+        loss, grad = eqx.filter_value_and_grad(compute_loss)(model, observations, actions, weights)
+        updates, opt_state = optimizer.update(grad, opt_state)
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state, loss
+
+    mean_return_by_ep = []
+
+    for i in range(500):
+        observations = []
+        actions = []
+        weights = []
+        returns = []
+        lengths = []
+
+        episode_rewards = []
+
+        obs, info = env.reset()
+
+        done = False
+
+        while True:
+            observations.append(obs)
+
+            act = policy_net(hax.named(jnp.array(obs), (Observation,)))
+            #print(act)
+            sample, rk = jax.random.split(sample)
+            act = hax.random.categorical(rk, act, Action).item()
+            obs, rew, done, truncated, info = env.step(act)
+
+            actions.append(act)
+            episode_rewards.append(rew)
+
+            if done:
+                episode_return = sum(episode_rewards)
+                returns.append(episode_return)
+                lengths.append(len(episode_rewards))
+                weights += [episode_return] * len(episode_rewards)
+                episode_rewards = []
+                obs, info = env.reset()
+                done = False
+
+                if len(observations) > Batch.size:
+                    break
+
+        observations = hax.named(jnp.array(observations[:Batch.size]), (Batch, Observation))
+        actions = hax.named(jnp.array(actions[:Batch.size]), (Batch,))
+        weights = hax.named(jnp.array(weights[:Batch.size]), (Batch,))
+        weights = (weights - hax.mean(weights, axis=Batch)) / (hax.std(weights, axis=Batch) + 1e-5)
+        policy_net, opt_state, loss = do_train_step(policy_net, opt_state, observations, actions, weights)
+        print(i, loss, np.mean(returns), np.mean(lengths))
+
+        mean_return_by_ep.append(np.mean(returns))
+
+    plt.plot(mean_return_by_ep)
+    plt.show()
+
+run()
\ No newline at end of file