diff --git a/app/preprocessing/transform_dataset.py b/app/preprocessing/transform_dataset.py
index cbc0cb872fd6b1b6186a647d186d28016f71a830..dbb1834849fe2695de872babe666b6141258800f 100644
--- a/app/preprocessing/transform_dataset.py
+++ b/app/preprocessing/transform_dataset.py
@@ -55,7 +55,7 @@ def _cast_columns(df: pd.DataFrame, column_type_mapping: Dict[str | int, str]) -
     return df
 
 def _split_array_column(df: pd.DataFrame) -> pd.DataFrame:
-    array_columns = [col for col in df.columns if isinstance(df[col].values[0], np.ndarray)] # Data is consistent in each row
+    array_columns = [col for col in df.columns if preprocessing_utils.is_column_of_type(df[col], np.ndarray)]
     for column in array_columns:
         array_dtype = df[column].iloc[0].dtype # First row must have a value
         stacked_arrays = np.stack(df[column].values, dtype=array_dtype) # is faster than df[column].apply(lambda vec: pd.Series(vec, dtype=array_dtype))
@@ -66,6 +66,10 @@ def _split_array_column(df: pd.DataFrame) -> pd.DataFrame:
 
     return df
 
+def _remove_string_columns(df: pd.DataFrame) -> pd.DataFrame:
+    string_columns = [col for col in df.columns if preprocessing_utils.is_column_of_type(df[col], str)]
+    return df.drop(columns=string_columns)
+
 def _drop_non_shared_columns(df: pd.DataFrame, shared_columns: Set[str]) -> pd.DataFrame:
     columns_to_drop = [column for column in df.columns if str(column) not in shared_columns]
     df = df.drop(columns=columns_to_drop)
@@ -102,6 +106,9 @@ def _transform_parquet_file(
     # Split arrays
     df = _split_array_column(df)
 
+    # Drop string columns
+    df = _remove_string_columns(df)
+
     print(f'Saving (unknown)')
     df.to_parquet(out_dir / filename)
     # df.to_csv(out_dir / f'{file.stem}.csv')
diff --git a/app/preprocessing/utils.py b/app/preprocessing/utils.py
index 56d1250e663f3289054faf110c3a50a9daa44fd6..3372b03eedf8853d5f40cd5b631e5a9d1398422e 100644
--- a/app/preprocessing/utils.py
+++ b/app/preprocessing/utils.py
@@ -2,6 +2,8 @@ from typing import List
 import shutil
 import os
 from pathlib import Path
+import pandas as pd
+
 from .file_type import FileType
 
 def recreate_dir(dir: Path) -> None:
@@ -12,3 +14,6 @@ def recreate_dir(dir: Path) -> None:
 
 def files_from_dataset(dataset_dir: Path, dataType: FileType):
     return [path for path in dataset_dir.glob(f'*{dataType.file_extension}') if path.is_file()]
+
+def is_column_of_type(column: pd.Series, type: type):
+    return isinstance(column.values[0], type)