diff --git a/export_import/tf_lite_export.py b/export_import/tf_lite_export.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f6e0ef24dbcd10e00a5d472cf6ea336449684d --- /dev/null +++ b/export_import/tf_lite_export.py @@ -0,0 +1,17 @@ +import tensorflow as tf +import sys +import os +import tensorflow_probability as tfp # fails otherwise + +policies_dir = "out/policies" +policy_name = sys.argv[1] +policy_dir = os.path.join(policies_dir, policy_name) + +converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"]) +tflite_policy = converter.convert() + +tflite_dir = os.path.join(policies_dir, '{}.tflite'.format(policy_name)) +with open(tflite_dir, 'wb') as f: + f.write(tflite_policy) + +print(tflite_dir)