Commit

Fix destreceiver api

aykut-bozkurt committed Jan 20, 2025
1 parent 30fb0c6 commit f4264d4
Showing 3 changed files with 65 additions and 97 deletions.
2 changes: 1 addition & 1 deletion src/parquet_copy_hook/copy_to.rs

@@ -31,7 +31,7 @@ pub(crate) fn execute_copy_to_with_dest_receiver(
     query_string: &CStr,
     params: &PgBox<ParamListInfoData>,
     query_env: &PgBox<QueryEnvironment>,
-    parquet_dest: PgBox<DestReceiver>,
+    parquet_dest: &PgBox<DestReceiver>,
 ) -> u64 {
     unsafe {
         debug_assert!(is_a(p_stmt.utilityStmt, T_CopyStmt));
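Note on the signature change: the receiver is now passed by reference, so ownership stays with the caller in hook.rs, which can still reach the receiver to shut it down when the COPY fails (see the hook.rs hunk at the end). A minimal, self-contained sketch of the ownership difference, with a hypothetical `DestBox` standing in for `PgBox<DestReceiver>`:

```rust
// hypothetical stand-in for PgBox<DestReceiver>
struct DestBox {
    uri: &'static str,
}

// takes &DestBox: the caller keeps ownership, as in the new signature
fn execute_copy_to(dest: &DestBox) -> u64 {
    println!("copying to {}", dest.uri);
    0
}

fn main() {
    let dest = DestBox { uri: "/tmp/example.parquet" };
    execute_copy_to(&dest);

    // with a by-value parameter, `dest` would be gone here; with a borrow,
    // an error path can still reach it for cleanup
    println!("shutting down receiver for {}", dest.uri);
}
```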
128 changes: 48 additions & 80 deletions src/parquet_copy_hook/copy_to_dest_receiver.rs
@@ -9,15 +9,10 @@ use pg_sys::{
 };
 use pgrx::{prelude::*, FromDatum, PgList, PgMemoryContexts, PgTupleDesc};

-use crate::{
-    arrow_parquet::{
-        compression::{PgParquetCompression, INVALID_COMPRESSION_LEVEL},
-        parquet_writer::{
-            ParquetWriterContext, DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES,
-        },
-        uri_utils::parse_uri,
-    },
-    pgrx_utils::{collect_attributes_for, CollectAttributesFor},
+use crate::arrow_parquet::{
+    compression::{PgParquetCompression, INVALID_COMPRESSION_LEVEL},
+    parquet_writer::{ParquetWriterContext, DEFAULT_ROW_GROUP_SIZE, DEFAULT_ROW_GROUP_SIZE_BYTES},
+    uri_utils::parse_uri,
 };

 #[repr(C)]
@@ -40,6 +35,7 @@ struct CopyToParquetDestReceiver {
     uri: *const c_char,
     copy_options: CopyToParquetOptions,
     per_copy_context: MemoryContext,
+    parquet_writer_context: *mut ParquetWriterContext,
 }

 impl CopyToParquetDestReceiver {
@@ -102,7 +98,7 @@
     fn write_tuples_to_parquet(&mut self) {
         debug_assert!(!self.tupledesc.is_null());

-        let tupledesc = unsafe { PgTupleDesc::from_pg(self.tupledesc) };
+        let tupledesc = unsafe { PgTupleDesc::from_pg_unchecked(self.tupledesc) };

         let tuples = unsafe { PgList::from_pg(self.collected_tuples) };
         let tuples = tuples
@@ -117,57 +113,33 @@
             })
             .collect::<Vec<_>>();

-        let current_parquet_writer_context =
-            peek_parquet_writer_context().expect("parquet writer context is not found");
+        let current_parquet_writer_context = unsafe {
+            self.parquet_writer_context
+                .as_mut()
+                .expect("parquet writer context is not found")
+        };
         current_parquet_writer_context.write_new_row_group(tuples);

         self.reset_collected_tuples();
     }

     fn cleanup(&mut self) {
-        unsafe { MemoryContextDelete(self.per_copy_context) };
-    }
-}
+        if !self.parquet_writer_context.is_null() {
+            unsafe { MemoryContextDelete(self.per_copy_context) };
+        }

-// stack to store parquet writer contexts for COPY TO.
-// This needs to be a stack since COPY command can be nested.
-static mut PARQUET_WRITER_CONTEXT_STACK: Vec<ParquetWriterContext> = vec![];
+        if !self.parquet_writer_context.is_null() {
+            let parquet_writer_context = unsafe { Box::from_raw(self.parquet_writer_context) };

-pub(crate) fn peek_parquet_writer_context() -> Option<&'static mut ParquetWriterContext> {
-    #[allow(static_mut_refs)]
-    unsafe {
-        PARQUET_WRITER_CONTEXT_STACK.last_mut()
-    }
-}
+            self.parquet_writer_context = std::ptr::null_mut();

-pub(crate) fn pop_parquet_writer_context(throw_error: bool) {
-    #[allow(static_mut_refs)]
-    let mut current_parquet_writer_context = unsafe { PARQUET_WRITER_CONTEXT_STACK.pop() };
+            drop(parquet_writer_context);
+        }

-    if current_parquet_writer_context.is_none() {
-        let level = if throw_error {
-            PgLogLevel::ERROR
-        } else {
-            PgLogLevel::DEBUG2
-        };
-
-        ereport!(
-            level,
-            PgSqlErrorCode::ERRCODE_INTERNAL_ERROR,
-            "parquet writer context stack is already empty"
-        );
-    } else {
-        current_parquet_writer_context.take();
+        self.collected_tuple_count = 0;
     }
 }
-
-pub(crate) fn push_parquet_writer_context(writer_ctx: ParquetWriterContext) {
-    #[allow(static_mut_refs)]
-    unsafe {
-        PARQUET_WRITER_CONTEXT_STACK.push(writer_ctx)
-    };
-}

 #[pg_guard]
 extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc: TupleDesc) {
     let parquet_dest = unsafe {
@@ -178,20 +150,20 @@ extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc:

     // bless tupledesc, otherwise lookup_row_tupledesc would fail for row types
     let tupledesc = unsafe { BlessTupleDesc(tupledesc) };
-    let tupledesc = unsafe { PgTupleDesc::from_pg(tupledesc) };
-
-    let attributes = collect_attributes_for(CollectAttributesFor::CopyTo, &tupledesc);
+    // from_pg_unchecked makes sure tupledesc is not dropped since it is an external tupledesc
+    let tupledesc = unsafe { PgTupleDesc::from_pg_unchecked(tupledesc) };
+
     // update the parquet dest receiver's missing fields
     parquet_dest.tupledesc = tupledesc.as_ptr();
     parquet_dest.collected_tuples = PgList::<HeapTupleData>::new().into_pg();
     parquet_dest.collected_tuple_column_sizes = unsafe {
         MemoryContextAllocZero(
             parquet_dest.per_copy_context,
-            std::mem::size_of::<i64>() * attributes.len(),
+            std::mem::size_of::<i64>() * tupledesc.len(),
         ) as *mut i64
     };
-    parquet_dest.natts = attributes.len();
+    parquet_dest.natts = tupledesc.len();

     let uri = unsafe { CStr::from_ptr(parquet_dest.uri) }
         .to_str()
@@ -203,11 +175,10 @@ extern "C" fn copy_startup(dest: *mut DestReceiver, _operation: i32, tupledesc:

     let compression_level = parquet_dest.copy_options.compression_level;

-    // parquet writer context is used throughout the COPY TO operation.
-    // This might be put into ParquetCopyDestReceiver, but it's hard to preserve repr(C).
+    // leak the parquet writer context since it will be used during the COPY operation
     let parquet_writer_context =
         ParquetWriterContext::new(uri, compression, compression_level, &tupledesc);
-    push_parquet_writer_context(parquet_writer_context);
+    parquet_dest.parquet_writer_context = Box::into_raw(Box::new(parquet_writer_context));
 }

 #[pg_guard]
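Taken together, the receiver now owns its `ParquetWriterContext` through a raw pointer instead of a process-global stack: `copy_startup` leaks the context with `Box::into_raw`, `write_tuples_to_parquet` reborrows it with `as_mut`, and `cleanup` reclaims it with `Box::from_raw` before nulling the field. A raw `*mut ParquetWriterContext` also keeps the struct's `#[repr(C)]` layout intact, the obstacle the old comment called out. A self-contained sketch of the pattern (all names are illustrative, not pg_parquet's API):

```rust
struct WriterContext {
    rows_written: usize,
}

#[repr(C)]
struct Receiver {
    // the real receiver carries the C-visible DestReceiver fields as well;
    // a raw pointer keeps the layout FFI-safe
    writer: *mut WriterContext,
}

impl Receiver {
    fn startup(&mut self) {
        // leak the context; the receiver now owns it through the raw pointer
        self.writer = Box::into_raw(Box::new(WriterContext { rows_written: 0 }));
    }

    fn receive(&mut self) {
        // reborrow the leaked context for the duration of one call
        let writer = unsafe { self.writer.as_mut() }.expect("writer context is not found");
        writer.rows_written += 1;
    }

    fn cleanup(&mut self) {
        if !self.writer.is_null() {
            // take ownership back so Drop runs exactly once
            let writer = unsafe { Box::from_raw(self.writer) };
            self.writer = std::ptr::null_mut();
            println!("dropping writer after {} rows", writer.rows_written);
        }
    }
}

fn main() {
    let mut receiver = Receiver { writer: std::ptr::null_mut() };
    receiver.startup();
    receiver.receive();
    receiver.cleanup();
    receiver.cleanup(); // second call is a no-op thanks to the null check
}
```

The null check plus re-nulling makes `cleanup` idempotent, which is what lets both the normal shutdown path and the error path in hook.rs invoke it without risking a double free.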
@@ -225,23 +196,19 @@ extern "C" fn copy_receive(slot: *mut TupleTableSlot, dest: *mut DestReceiver) -
         // extracts all attributes in statement "SELECT * FROM table"
         slot_getallattrs(slot);

-        let slot = PgBox::from_pg(slot);
-
         let natts = parquet_dest.natts;

-        let datums = slot.tts_values;
-        let datums = std::slice::from_raw_parts(datums, natts);
+        let datums = std::slice::from_raw_parts((*slot).tts_values, natts);

-        let nulls = slot.tts_isnull;
-        let nulls = std::slice::from_raw_parts(nulls, natts);
+        let nulls = std::slice::from_raw_parts((*slot).tts_isnull, natts);

         let datums: Vec<Option<Datum>> = datums
             .iter()
             .zip(nulls)
             .map(|(datum, is_null)| if *is_null { None } else { Some(*datum) })
             .collect();

-        let tupledesc = PgTupleDesc::from_pg(parquet_dest.tupledesc);
+        let tupledesc = PgTupleDesc::from_pg_unchecked(parquet_dest.tupledesc);

         let column_sizes = tuple_column_sizes(&datums, &tupledesc);
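The receive path now reads the slot's raw arrays directly instead of wrapping the slot in a `PgBox`. A standalone analog of the datum/null zipping, with `u64` standing in for `pg_sys::Datum` and plain `Vec`s in place of the `tts_values`/`tts_isnull` arrays:

```rust
fn main() {
    let natts = 3;

    // in copy_receive these two parallel arrays come from
    // std::slice::from_raw_parts on (*slot).tts_values and (*slot).tts_isnull
    let values: Vec<u64> = vec![10, 0, 30];
    let is_null: Vec<bool> = vec![false, true, false];

    // fold the parallel arrays into a single Vec<Option<_>> for the tuple
    let datums: Vec<Option<u64>> = values[..natts]
        .iter()
        .zip(&is_null[..natts])
        .map(|(datum, is_null)| if *is_null { None } else { Some(*datum) })
        .collect();

    assert_eq!(datums, vec![Some(10), None, Some(30)]);
}
```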

@@ -278,9 +245,6 @@ extern "C" fn copy_shutdown(dest: *mut DestReceiver) {
     }

     parquet_dest.cleanup();
-
-    let throw_error = true;
-    pop_parquet_writer_context(throw_error);
 }

 #[pg_guard]
@@ -344,6 +308,7 @@ pub extern "C" fn create_copy_to_parquet_dest_receiver(
     parquet_dest.dest.mydest = CommandDest::DestCopyOut;
     parquet_dest.uri = uri;
     parquet_dest.tupledesc = std::ptr::null_mut();
+    parquet_dest.parquet_writer_context = std::ptr::null_mut();
     parquet_dest.natts = 0;
     parquet_dest.collected_tuple_count = 0;
     parquet_dest.collected_tuples = std::ptr::null_mut();
@@ -361,29 +326,32 @@ fn tuple_column_sizes(tuple_datums: &[Option<Datum>], tupledesc: &PgTupleDesc) -
     let mut column_sizes = vec![];

     for (idx, column_datum) in tuple_datums.iter().enumerate() {
+        if column_datum.is_none() {
+            column_sizes.push(0);
+            continue;
+        }
+
+        let column_datum = column_datum.as_ref().expect("column datum is None");
+
         let attribute = tupledesc.get(idx).expect("cannot get attribute");

         let typoid = attribute.type_oid();

         let mut typlen = -1_i16;
         let mut typbyval = false;
         unsafe { get_typlenbyval(typoid.value(), &mut typlen, &mut typbyval) };

-        let column_size = if let Some(column_datum) = column_datum {
-            if typlen == -1 {
-                (unsafe { toast_raw_datum_size(*column_datum) }) as i32 - VARHDRSZ as i32
-            } else if typlen == -2 {
-                // cstring
-                let cstring = unsafe {
-                    CString::from_datum(*column_datum, false)
-                        .expect("cannot get cstring from datum")
-                };
-                cstring.as_bytes().len() as i32 + 1
-            } else {
-                // fixed size type
-                typlen as i32
-            }
+        let column_size = if typlen == -1 {
+            (unsafe { toast_raw_datum_size(*column_datum) }) as i32 - VARHDRSZ as i32
+        } else if typlen == -2 {
+            // cstring
+            let cstring = unsafe {
+                CString::from_datum(*column_datum, false).expect("cannot get cstring from datum")
+            };
+            cstring.as_bytes().len() as i32 + 1
         } else {
-            0
+            // fixed size type
+            typlen as i32
         };

         column_sizes.push(column_size);
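For reference, the branching follows Postgres's `typlen` convention: `-1` marks a varlena value whose untoasted size is `toast_raw_datum_size` minus the 4-byte `VARHDRSZ` header, `-2` marks a NUL-terminated cstring, and a positive value is a fixed-width type. A toy version of the arithmetic, with the Postgres calls replaced by plain inputs (names here are illustrative):

```rust
/// Illustrative stand-in for the per-column branch above; `raw_size` plays
/// the role of toast_raw_datum_size (typlen -1) or the cstring byte length
/// (typlen -2).
fn column_size(typlen: i16, raw_size: usize) -> i32 {
    const VARHDRSZ: i32 = 4; // varlena header size on standard builds

    if typlen == -1 {
        raw_size as i32 - VARHDRSZ // varlena: untoasted payload size
    } else if typlen == -2 {
        raw_size as i32 + 1 // cstring: bytes plus the trailing NUL
    } else {
        typlen as i32 // fixed size type
    }
}

fn main() {
    assert_eq!(column_size(-1, 13), 9); // e.g. a short text datum
    assert_eq!(column_size(-2, 5), 6); // e.g. the cstring "hello"
    assert_eq!(column_size(4, 0), 4); // e.g. an int4
}
```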
32 changes: 16 additions & 16 deletions src/parquet_copy_hook/hook.rs
@@ -23,7 +23,6 @@ use crate::{
 use super::{
     copy_from::{execute_copy_from, pop_parquet_reader_context},
     copy_to::execute_copy_to_with_dest_receiver,
-    copy_to_dest_receiver::pop_parquet_writer_context,
     copy_utils::{copy_to_stmt_compression, validate_copy_from_options, validate_copy_to_options},
 };

@@ -62,25 +61,26 @@ fn process_copy_to_parquet(
     let compression = copy_to_stmt_compression(p_stmt, uri.clone());
     let compression_level = copy_to_stmt_compression_level(p_stmt, uri.clone());

-    PgTryBuilder::new(|| {
-        let parquet_dest = create_copy_to_parquet_dest_receiver(
-            uri_as_string(&uri).as_pg_cstr(),
-            &row_group_size,
-            &row_group_size_bytes,
-            &compression,
-            &compression_level.unwrap_or(INVALID_COMPRESSION_LEVEL),
-        );
+    let parquet_dest = create_copy_to_parquet_dest_receiver(
+        uri_as_string(&uri).as_pg_cstr(),
+        &row_group_size,
+        &row_group_size_bytes,
+        &compression,
+        &compression_level.unwrap_or(INVALID_COMPRESSION_LEVEL),
+    );

-        let parquet_dest = unsafe { PgBox::from_pg(parquet_dest) };
+    let parquet_dest = unsafe { PgBox::from_pg(parquet_dest) };

-        execute_copy_to_with_dest_receiver(p_stmt, query_string, params, query_env, parquet_dest)
+    PgTryBuilder::new(|| {
+        execute_copy_to_with_dest_receiver(p_stmt, query_string, params, query_env, &parquet_dest)
     })
     .catch_others(|cause| {
-        // make sure to pop the parquet writer context
-        // In case we did not push the context, we should not throw an error while popping
-        let throw_error = false;
-        pop_parquet_writer_context(throw_error);
-
+        // make sure to cleanup parquet dest receiver
+        if let Some(shutdown_callback) = parquet_dest.rShutdown {
+            unsafe {
+                shutdown_callback(parquet_dest.as_ptr());
+            }
+        }
         cause.rethrow()
     })
     .execute()
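The receiver is now created before the guarded block, so the error handler can reach it and invoke its shutdown callback before rethrowing. A self-contained analog of that control flow, using `std::panic::catch_unwind` as a stand-in for pgrx's `PgTryBuilder` and a hypothetical `Receiver` in place of the dest receiver:

```rust
use std::panic::{catch_unwind, AssertUnwindSafe};

// hypothetical stand-in for the parquet dest receiver
struct Receiver {
    cleaned_up: bool,
}

impl Receiver {
    fn shutdown(&mut self) {
        self.cleaned_up = true;
    }
}

fn execute_copy(_receiver: &Receiver) {
    panic!("simulated error during COPY TO");
}

fn main() {
    // created before the guarded block, exactly so the handler can reach it
    let mut receiver = Receiver { cleaned_up: false };

    if catch_unwind(AssertUnwindSafe(|| execute_copy(&receiver))).is_err() {
        // mirror of the .catch_others branch: clean up the receiver first;
        // the real hook then rethrows with cause.rethrow()
        receiver.shutdown();
    }

    assert!(receiver.cleaned_up);
}
```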
