diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..9c920976651eeae7ff4e30326bac079ca960d75e --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "main.py", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/HelloWorldEnv.py b/HelloWorldEnv.py new file mode 100644 index 0000000000000000000000000000000000000000..fc09bbf7ddb0aef471ab5d3d094d70b4d6c558ce --- /dev/null +++ b/HelloWorldEnv.py @@ -0,0 +1,59 @@ +from typing import Text +import airsim +import numpy as np +import HelloWorldEnv + +from tf_agents.environments.py_environment import PyEnvironment +from tf_agents.environments.tf_py_environment import TFPyEnvironment +from tf_agents.trajectories import time_step as ts +from tf_agents.typing import types +from tf_agents.typing.types import NestedArraySpec, TimeStep, NestedArray +from tf_agents.specs.array_spec import BoundedArraySpec, ArraySpec + +class HelloWorldEnv(PyEnvironment): + def __init__(self, ip: str, handle_auto_reset: bool = False): + self._action_spec = BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=2, name='action') + self._observation_spec = ArraySpec(shape=(3,), dtype=np.float32, name='observation') + self._client = airsim.MultirotorClient(ip=ip) + super(HelloWorldEnv, self).__init__(handle_auto_reset) + + def action_spec(self) -> NestedArraySpec: + return self._action_spec + + def observation_spec(self) -> NestedArraySpec: + return self._observation_spec + + def _reset(self) -> TimeStep: + self._client.reset() + self._client.confirmConnection() + self._client.armDisarm(True) + self._client.enableApiControl(True) + return ts.restart(self._getObservation()) + + def _step(self, action: NestedArray) -> TimeStep: + vz = 0 + if action == 1: # up + vz = -1 + elif action == 2: # down + vz = 1 + elif action == 0: # stop + vz == 0 + + self._client.cancelLastTask() + self._client.moveByVelocityAsync(0, 0, vz, 5) # max 5 seconds flight without a new command until the drone stops + + obs = self._getObservation() + reward = HelloWorldEnv._calculate_reward(obs) + + return ts.transition(obs, reward=reward) + + def render(self, mode: Text = 'rgb_array') -> NestedArray | None: + raise NotImplementedError() + + def _getObservation(self): + geo_point = self._client.getGpsData().gnss.geo_point + return geo_point.latitude, geo_point.longitude, geo_point.altitude + + @staticmethod + def _calculate_reward(obs: tuple[float, float, float]): + return -abs(133 - obs[2]) diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1b3e4111e9f42407678687d61ea82f07b6129c --- /dev/null +++ b/main.py @@ -0,0 +1,10 @@ +from HelloWorldEnv import HelloWorldEnv +from time import sleep + +SIM_IP = "192.168.8.195" + +env = HelloWorldEnv(ip=SIM_IP) +env.reset() +env.step(1) +sleep(2) +env.step(1) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..732b5a9f9b75f4e5bed2efca2b3845612f28fe43 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +nvidia-cudnn-cu11==8.6.0.163 +tensorflow==2.13.* +nvidia-tensorrt +msgpack-rpc-python +airsim +tf-agents