From 66876851b9f300814d8242148f51524c9fb18a8e Mon Sep 17 00:00:00 2001
From: Andri Joos <andri@joos.io>
Date: Sat, 9 Nov 2024 13:00:14 +0100
Subject: [PATCH] drop string columns

---
 app/preprocessing/transform_dataset.py | 9 ++++++++-
 app/preprocessing/utils.py             | 5 +++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/app/preprocessing/transform_dataset.py b/app/preprocessing/transform_dataset.py
index cbc0cb8..dbb1834 100644
--- a/app/preprocessing/transform_dataset.py
+++ b/app/preprocessing/transform_dataset.py
@@ -55,7 +55,7 @@ def _cast_columns(df: pd.DataFrame, column_type_mapping: Dict[str | int, str]) -
     return df
 
 def _split_array_column(df: pd.DataFrame) -> pd.DataFrame:
-    array_columns = [col for col in df.columns if isinstance(df[col].values[0], np.ndarray)] # Data is consistent in each row
+    array_columns = [col for col in df.columns if preprocessing_utils.is_column_of_type(df[col], np.ndarray)]
     for column in array_columns:
         array_dtype = df[column].iloc[0].dtype # First row must have a value
         stacked_arrays = np.stack(df[column].values, dtype=array_dtype) # is faster than df[column].apply(lambda vec: pd.Series(vec, dtype=array_dtype))
@@ -66,6 +66,10 @@ def _split_array_column(df: pd.DataFrame) -> pd.DataFrame:
 
     return df
 
+def _remove_string_columns(df: pd.DataFrame) -> pd.DataFrame:
+    string_columns = [col for col in df.columns if preprocessing_utils.is_column_of_type(df[col], str)]
+    return df.drop(columns=string_columns)
+
 def _drop_non_shared_columns(df: pd.DataFrame, shared_columns: Set[str]) -> pd.DataFrame:
     columns_to_drop = [column for column in df.columns if str(column) not in shared_columns]
     df = df.drop(columns=columns_to_drop)
@@ -102,6 +106,9 @@ def _transform_parquet_file(
     # Split arrays
     df = _split_array_column(df)
 
+    # Drop string columns
+    df = _remove_string_columns(df)
+
     print(f'Saving {filename}')
     df.to_parquet(out_dir / filename)
     # df.to_csv(out_dir / f'{file.stem}.csv')
diff --git a/app/preprocessing/utils.py b/app/preprocessing/utils.py
index 56d1250..3372b03 100644
--- a/app/preprocessing/utils.py
+++ b/app/preprocessing/utils.py
@@ -2,6 +2,8 @@ from typing import List
 import shutil
 import os
 from pathlib import Path
+import pandas as pd
+
 from .file_type import FileType
 
 def recreate_dir(dir: Path) -> None:
@@ -12,3 +14,6 @@ def recreate_dir(dir: Path) -> None:
 
 def files_from_dataset(dataset_dir: Path, dataType: FileType):
     return [path for path in dataset_dir.glob(f'*{dataType.file_extension}') if path.is_file()]
+
+def is_column_of_type(column: pd.Series, type: type):
+    return isinstance(column.values[0], type)
-- 
GitLab