diff --git a/app/preprocessing/transform_dataset.py b/app/preprocessing/transform_dataset.py
index 28c97c1f3f54ec746e84cfceef856dd3b073f3b0..899e74f21f21c2eb853dcd194a884c5303abc193 100644
--- a/app/preprocessing/transform_dataset.py
+++ b/app/preprocessing/transform_dataset.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, Tuple, List, Set
 import pandas as pd
 from pathlib import Path
 import json
@@ -9,6 +9,7 @@ from joblib import Parallel, delayed
 from functools import partial
 import psutil
 import multiprocessing
+import pyarrow.parquet
 
 from . import utils as preprocessing_utils
 from .file_type import FileType
@@ -51,11 +52,18 @@ def _cast_columns(df: pd.DataFrame, column_type_mapping: Dict[str | int, str]) -
 
     return df
 
+def _drop_non_shared_columns(df: pd.DataFrame, shared_columns: Set[str]) -> pd.DataFrame:
+    columns_to_drop = [column for column in df.columns if str(column) not in shared_columns]
+    df = df.drop(columns=columns_to_drop)
+
+    return df
+
 def _transform_parquet_file(
     file: Path,
     state_id_name_mapping: Dict[int, str],
     column_name_type_mapping: Dict[str, str],
-    out_dir: Path
+    shared_columns: Set[str],
+    out_dir: Path,
 ) -> None:
     pd.set_option('future.no_silent_downcasting', True)
     filename = file.name
@@ -66,6 +74,8 @@ def _transform_parquet_file(
     print(f'Processing {filename}')
 
     df = df.sort_values(by='FrameCounter')
 
+    df = _drop_non_shared_columns(df, shared_columns)
+
     # Forward fill
     df = df.ffill()
@@ -81,9 +91,24 @@ def _transform_parquet_file(
 
     print(f'Processed {filename}')
 
+def _shared_columns(parquet_files: List[Path]) -> Set[str]:
+    if len(parquet_files) == 0:
+        return set()
+
+    shared_columns: Set[str] = set(pyarrow.parquet.read_schema(parquet_files[0]).names)
+    for file in parquet_files[1:]:
+        columns = pyarrow.parquet.read_schema(file).names
+        shared_columns.intersection_update(columns)
+
+    return shared_columns
+
 def transform_dataset(dataset_dir: Path, out_dir: Path, state_description_file: Path, parallelize: bool = True) -> None:
     preprocessing_utils.recreate_dir(out_dir)
 
+    parquet_files = preprocessing_utils.files_from_dataset(dataset_dir, FileType.Parquet)
+
+    shared_columns = _shared_columns(parquet_files)
+
     state_id_name_mapping: Dict[int, str] = None
     column_name_type_mapping: Dict[str, str] = None
     with open(state_description_file, 'r') as f:
@@ -91,11 +116,10 @@ def transform_dataset(dataset_dir: Path, out_dir: Path, state_description_file:
         state_id_name_mapping = {v['stateId']: k for k, v in state_descriptions.items()}
         column_name_type_mapping = {k: v['dataType'] for k, v in state_descriptions.items()}
 
-    parquet_files = preprocessing_utils.files_from_dataset(dataset_dir, FileType.Parquet)
-
     _transform_parquet_file_function_with_args = partial(_transform_parquet_file,
                                                          state_id_name_mapping=state_id_name_mapping,
                                                          column_name_type_mapping=column_name_type_mapping,
+                                                         shared_columns=shared_columns,
                                                          out_dir=out_dir,
                                                          )
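For context on the new `_shared_columns` helper: it reads only the parquet schemas (via `pyarrow.parquet.read_schema`, not the data) and intersects their column names, and `_drop_non_shared_columns` then reduces each file to the columns present in every file. Below is a minimal standalone sketch of that behaviour, outside the diff; the sample files and their column names are made up for illustration.

```python
# Standalone sketch (not part of the change): schema-level column intersection
# across two small parquet files with partly differing columns.
import tempfile
from pathlib import Path

import pandas as pd
import pyarrow.parquet

tmp = Path(tempfile.mkdtemp())
# Hypothetical sample files; only 'FrameCounter' is common to both.
pd.DataFrame({'FrameCounter': [1, 2], 'Speed': [0.1, 0.2]}).to_parquet(tmp / 'a.parquet', index=False)
pd.DataFrame({'FrameCounter': [1, 2], 'Altitude': [100, 110]}).to_parquet(tmp / 'b.parquet', index=False)

# Only the schemas are read, so no file data has to be loaded to find the shared columns.
shared = set(pyarrow.parquet.read_schema(tmp / 'a.parquet').names)
shared.intersection_update(pyarrow.parquet.read_schema(tmp / 'b.parquet').names)
print(shared)  # {'FrameCounter'}
```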