Skip to content
Snippets Groups Projects
Commit a8e4c8f4 authored by Andri Joos's avatar Andri Joos :blush:
Browse files

add environment

parent 0c672e90
No related branches found
No related tags found
1 merge request!1Resolve "Setup Environment"
from typing import Text
import airsim
import numpy as np
import HelloWorldEnv
from tf_agents.environments.py_environment import PyEnvironment
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.typing.types import NestedArraySpec, TimeStep, NestedArray
from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec
class HelloWorldEnv(PyEnvironment):
def __init__(self, ip: str, handle_auto_reset: bool = False):
self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')
self._observation_spec = ArraySpec(shape=(3,), dtype=np.float32, name='observation')
self._client = airsim.MultirotorClient(ip=ip)
super(HelloWorldEnv, self).__init__(handle_auto_reset)
def action_spec(self) -> NestedArraySpec:
return self._action_spec
def observation_spec(self) -> NestedArraySpec:
return self._observation_spec
def _reset(self) -> TimeStep:
self._client.reset()
self._client.confirmConnection()
self._client.armDisarm(True)
self._client.enableApiControl(True)
return ts.restart(self._getObservation())
def _step(self, action: NestedArray) -> TimeStep:
vz = 0
if action == 1: # up
vz = -1
elif action == 2: # down
vz = 1
elif action == 0: # stop
vz == 0
self._client.cancelLastTask()
self._client.moveByVelocityAsync(0, 0, vz, 5) # max 5 seconds flight without a new command until the drone stops
obs = self._getObservation()
reward = HelloWorldEnv._calculate_reward(obs)
return ts.transition(obs, reward=reward)
def render(self, mode: Text = 'rgb_array') -> NestedArray | None:
raise NotImplementedError()
def _getObservation(self):
geo_point = self._client.getGpsData().gnss.geo_point
return geo_point.latitude, geo_point.longitude, geo_point.altitude
@staticmethod
def _calculate_reward(obs: tuple[float, float, float]):
return -abs(133 - obs[2])
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment