From 0608cd70979fff4586c174ee1bb6e04a913b1609 Mon Sep 17 00:00:00 2001 From: Andri Joos <andri@joos.io> Date: Fri, 29 Sep 2023 17:42:50 +0200 Subject: [PATCH] remove unused infos from env --- HelloWorldEnv.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/HelloWorldEnv.py b/HelloWorldEnv.py index 154e135..fc57073 100644 --- a/HelloWorldEnv.py +++ b/HelloWorldEnv.py @@ -13,7 +13,7 @@ from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec class HelloWorldEnv(PyEnvironment): def __init__(self, ip: str, handle_auto_reset: bool = False): self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=2, name='action') - self._observation_spec = ArraySpec(shape=(3,), dtype=np.float32, name='observation') + self._observation_spec = ArraySpec(shape=(1,), dtype=np.float32, name='observation') self._client = airsim.MultirotorClient(ip=ip) super(HelloWorldEnv, self).__init__(handle_auto_reset) @@ -52,8 +52,10 @@ class HelloWorldEnv(PyEnvironment): def _getObservation(self): geo_point = self._client.getGpsData().gnss.geo_point - return geo_point.latitude, geo_point.longitude, geo_point.altitude + # return geo_point.latitude, geo_point.longitude, geo_point.altitude + return (geo_point.altitude,) @staticmethod def _calculate_reward(obs: tuple[float, float, float]): - return -abs(133 - obs[2]) + # return -abs(133 - obs[2]) + return -abs(133 - (obs[0])) -- GitLab