diff --git a/CHANGELOG.md b/CHANGELOG.md index fe4ffca070..8dc835bef8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ All notable changes to this project will be documented in this file. - The `code` and `file` fields on the `javascript` processor docs no longer erroneously mention interpolation support. (@mihaitodor) - The `postgres_cdc` now correctly handles `null` values. (@rockwotj) - Fix an issue in `aws_sqs` with refreshing in-flight message leases which could prevent acks from processed. (@rockwotj) +- Fix an issue with `postgres_cdc` with TOAST values not being propagated with `REPLICA IDENTITY FULL`. (@rockwotj) +- Fix a initial snapshot streaming consistency issue with `postgres_cdc`. (@rockwotj) ## 4.44.0 - 2024-12-13 diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 337b25315f..fbf4ab94cc 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -325,9 +325,11 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher }) monitorLoop.Start() defer monitorLoop.Stop() - ctx, _ := p.stopSig.SoftStopCtx(context.Background()) + ctx, cancel := p.stopSig.SoftStopCtx(context.Background()) + defer cancel() defer func() { - ctx, _ := p.stopSig.HardStopCtx(context.Background()) + ctx, cancel := p.stopSig.HardStopCtx(context.Background()) + defer cancel() if err := batcher.Close(ctx); err != nil { p.logger.Errorf("unable to close batcher: %s", err) } diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index edf9f465cd..fd553ba7b6 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -137,6 +137,11 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version return err } + _, err = db.Exec("CREATE TABLE IF NOT EXISTS large_values (id serial PRIMARY KEY, value TEXT);") + if err != nil { + return err + } + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") @@ -757,3 +762,106 @@ file: }) } } + +func TestIntegrationTOASTValues(t *testing.T) { + t.Parallel() + integration.CheckSkip(t) + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16") + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + _, err = db.Exec(`ALTER TABLE large_values REPLICA IDENTITY FULL;`) + require.NoError(t, err) + + const stringSize = 400_000 + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + + require.NoError(t, err) + + // Insert a large >1MiB value + _, err = db.Exec(`INSERT INTO large_values (id, value) VALUES ($1, $2);`, 1, strings.Repeat("foo", stringSize)) + require.NoError(t, err) + + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + template := fmt.Sprintf(` +pg_stream: + dsn: %s + slot_name: test_slot_native_decoder + stream_snapshot: true + snapshot_batch_size: 1 + schema: public + tables: + - large_values +`, databaseURL) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + license.InjectTestService(streamOut.Resources()) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1 + }, time.Second*10, time.Millisecond*100) + + _, err = db.Exec(`UPDATE large_values SET value=$1;`, strings.Repeat("bar", stringSize)) + require.NoError(t, err) + _, err = db.Exec(`UPDATE large_values SET id=$1;`, 3) + require.NoError(t, err) + _, err = db.Exec(`DELETE FROM large_values`) + require.NoError(t, err) + _, err = db.Exec(`INSERT INTO large_values (id, value) VALUES ($1, $2);`, 2, strings.Repeat("qux", stringSize)) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + outBatchMut.Lock() + defer outBatchMut.Unlock() + assert.Len(c, outBatches, 5, "got: %#v", outBatches) + }, time.Second*10, time.Millisecond*100) + require.JSONEq(t, `{"id":1, "value": "`+strings.Repeat("foo", stringSize)+`"}`, outBatches[0], "GOT: %s", outBatches[0]) + require.JSONEq(t, `{"id":1, "value": "`+strings.Repeat("bar", stringSize)+`"}`, outBatches[1], "GOT: %s", outBatches[1]) + require.JSONEq(t, `{"id":3, "value": "`+strings.Repeat("bar", stringSize)+`"}`, outBatches[2], "GOT: %s", outBatches[2]) + require.JSONEq(t, `{"id":3, "value": "`+strings.Repeat("bar", stringSize)+`"}`, outBatches[3], "GOT: %s", outBatches[3]) + require.JSONEq(t, `{"id":2, "value": "`+strings.Repeat("qux", stringSize)+`"}`, outBatches[4], "GOT: %s", outBatches[4]) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 452574b774..df720740af 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -250,7 +250,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.errors <- fmt.Errorf("failed to process snapshot: %w", err) return } - ctx, _ := stream.shutSig.SoftStopCtx(context.Background()) + ctx, done := stream.shutSig.SoftStopCtx(context.Background()) + defer done() if err := stream.startLr(ctx, lsnrestart); err != nil { stream.errors <- fmt.Errorf("failed to start logical replication: %w", err) return @@ -338,7 +339,8 @@ func (s *Stream) streamMessages(currentLSN LSN) error { lastEmittedCommitLSN := currentLSN commitLSN := func(force bool) error { - ctx, _ := s.shutSig.HardStopCtx(context.Background()) + ctx, done := s.shutSig.HardStopCtx(context.Background()) + defer done() ackedLSN := s.getAckedLSN() if ackedLSN == lastEmittedLSN { ackedLSN = lastEmittedCommitLSN @@ -358,7 +360,8 @@ func (s *Stream) streamMessages(currentLSN LSN) error { } }() - ctx, _ := s.shutSig.SoftStopCtx(context.Background()) + ctx, done := s.shutSig.SoftStopCtx(context.Background()) + defer done() for !s.shutSig.IsSoftStopSignalled() { if err := commitLSN(time.Now().After(s.nextStandbyMessageDeadline)); err != nil { return err @@ -388,7 +391,6 @@ func (s *Stream) streamMessages(currentLSN LSN) error { s.logger.Warn("received malformatted with no data") continue } - switch msg.Data[0] { case PrimaryKeepaliveMessageByteID: pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) @@ -420,6 +422,8 @@ func (s *Stream) streamMessages(currentLSN LSN) error { lastEmittedLSN = msgLSN lastEmittedCommitLSN = msgLSN } + default: + return fmt.Errorf("unknown message type: %c", msg.Data[0]) } } // clean shutdown, return nil @@ -465,10 +469,13 @@ func (s *Stream) processChange(ctx context.Context, msgLSN LSN, xld XLogData, re } func (s *Stream) processSnapshot() error { - if err := s.snapshotter.prepare(); err != nil { + ctx, done := s.shutSig.SoftStopCtx(context.Background()) + defer done() + if err := s.snapshotter.prepare(ctx); err != nil { return fmt.Errorf("failed to prepare database snapshot - snapshot may be expired: %w", err) } defer func() { + s.logger.Debugf("Finished snapshot processing") if err := s.snapshotter.releaseSnapshot(); err != nil { s.logger.Warnf("Failed to release database snapshot: %v", err.Error()) } @@ -491,8 +498,6 @@ func (s *Stream) processSnapshot() error { offset = 0 ) - ctx, _ := s.shutSig.SoftStopCtx(context.Background()) - avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(ctx, table) if err != nil { return fmt.Errorf("failed to calculate average row size for table %v: %w", table, err) @@ -549,7 +554,7 @@ func (s *Stream) processSnapshot() error { var rowsCount = 0 rowsStart := time.Now() totalScanDuration := time.Duration(0) - totalWaitingFromBenthos := time.Duration(0) + sendDuration := time.Duration(0) for snapshotRows.Next() { rowsCount += 1 @@ -596,14 +601,17 @@ func (s *Stream) processSnapshot() error { case <-s.shutSig.SoftStopChan(): return nil } - totalWaitingFromBenthos += time.Since(waitingFromBenthos) + sendDuration += time.Since(waitingFromBenthos) + } + if snapshotRows.Err() != nil { + return fmt.Errorf("failed to close snapshot data iterator for table %v: %w", table, snapshotRows.Err()) } batchEnd := time.Since(rowsStart) s.logger.Debugf("Batch duration: %v %s \n", batchEnd, tableName) s.logger.Debugf("Scan duration %v %s\n", totalScanDuration, tableName) - s.logger.Debugf("Waiting from benthos duration %v %s\n", totalWaitingFromBenthos, tableName) + s.logger.Debugf("Send duration %v %s\n", sendDuration, tableName) offset += batchSize @@ -672,7 +680,8 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[s func (s *Stream) Stop(ctx context.Context) error { s.shutSig.TriggerSoftStop() var wg errgroup.Group - stopNowCtx, _ := s.shutSig.HardStopCtx(ctx) + stopNowCtx, done := s.shutSig.HardStopCtx(ctx) + defer done() wg.Go(func() error { return s.pgConn.Close(stopNowCtx) }) diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 0da56e686f..ade46ce0f9 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -271,7 +271,7 @@ func CreateReplicationSlot( var snapshotResponse SnapshotCreationResponse if options.SnapshotAction == "export" { var err error - snapshotResponse, err = snapshotter.initSnapshotTransaction() + snapshotResponse, err = snapshotter.initSnapshotTransaction(ctx) if err != nil { return CreateReplicationSlotResult{}, err } diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 4bfbbd79e7..ef1461dd6d 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -74,10 +74,8 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM for idx, col := range logicalMsg.Tuple.Columns { colName := rel.Columns[idx].Name switch col.DataType { - case 'n': // null + case 'n', 'u': // null or unchanged toast values[colName] = nil - case 'u': // unchanged toast - // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. case 't': //text val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) if err != nil { @@ -104,13 +102,29 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM case 'n': // null values[colName] = nil case 'u': // unchanged toast - // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. + // In the case of an update of an unchanged toast value and the replica is set to + // IDENTITY FULL, we need to look at the old tuple in order to get the data, it's + // just marked as unchanged in the new tuple. + if idx < len(logicalMsg.OldTuple.Columns) { + col = logicalMsg.OldTuple.Columns[idx] + switch col.DataType { + case 'n', 'u': + values[colName] = nil + continue + case 't': + default: + return nil, fmt.Errorf("unable to decode column data, unknown data type: %d", col.DataType) + } + } + fallthrough case 't': //text val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) if err != nil { return nil, fmt.Errorf("unable to decode column data: %w", err) } values[colName] = val + default: + return nil, fmt.Errorf("unable to decode column data, unknown data type: %d", col.DataType) } } message.Data = values @@ -126,16 +140,15 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM for idx, col := range logicalMsg.OldTuple.Columns { colName := rel.Columns[idx].Name switch col.DataType { - case 'n': // null + case 'n', 'u': // null or unchanged toast values[colName] = nil - case 'u': // unchanged toast - // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. case 't': //text val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) if err != nil { return nil, fmt.Errorf("unable to decode column data: %w", err) } values[colName] = val + default: } } message.Data = values diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 37b8d0c7b2..85c642cca5 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -36,9 +36,12 @@ type SnapshotCreationResponse struct { // Therefore Snapshotter opens another connection to the database and sets the transaction to the snapshot. // This allows you to read the data that was in the database at the time of the snapshot creation. type Snapshotter struct { - pgConnection *sql.DB - snapshotCreateConnection *sql.DB - logger *service.Logger + pool *sql.DB + logger *service.Logger + // Only needed for older PG versions, holds the snapshot open for the reader + snapshotTxn *sql.Tx + // The TXN for the snapshot phase + readerTxn *sql.Tx snapshotName string @@ -52,41 +55,35 @@ func NewSnapshotter(dbDSN string, logger *service.Logger, version int) (*Snapsho return nil, err } - snapshotCreateConnection, err := openPgConnectionFromConfig(dbDSN) - if err != nil { - return nil, err - } - return &Snapshotter{ - pgConnection: pgConn, - snapshotCreateConnection: snapshotCreateConnection, - logger: logger, - version: version, + pool: pgConn, + logger: logger, + version: version, }, nil } -func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error) { +func (s *Snapshotter) initSnapshotTransaction(ctx context.Context) (SnapshotCreationResponse, error) { if s.version > 14 { return SnapshotCreationResponse{}, errors.New("snapshot is exported by default for versions above PG14") } + if s.snapshotTxn != nil { + return SnapshotCreationResponse{}, errors.New("snapshot already exists") + } var snapshotName sql.NullString - snapshotRow, err := s.pgConnection.Query(`BEGIN; SELECT pg_export_snapshot();`) + tx, err := s.pool.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true, Isolation: sql.LevelRepeatableRead}) if err != nil { - return SnapshotCreationResponse{}, fmt.Errorf("cant get exported snapshot for initial streaming %w pg version: %d", err, s.version) + return SnapshotCreationResponse{}, fmt.Errorf("unable to begin a tx to export a snapshot: %w pg version: %d", err, s.version) } - + s.snapshotTxn = tx + snapshotRow := tx.QueryRowContext(ctx, `SELECT pg_export_snapshot();`) if snapshotRow.Err() != nil { - return SnapshotCreationResponse{}, fmt.Errorf("can get avg row size due to query failure: %w", snapshotRow.Err()) + return SnapshotCreationResponse{}, fmt.Errorf("unable to get snapshot name: %w", snapshotRow.Err()) } - if snapshotRow.Next() { - if err = snapshotRow.Scan(&snapshotName); err != nil { - return SnapshotCreationResponse{}, fmt.Errorf("cant scan snapshot name into string: %w", err) - } - } else { - return SnapshotCreationResponse{}, errors.New("cant get avg row size; 0 rows returned") + if err = snapshotRow.Scan(&snapshotName); err != nil { + return SnapshotCreationResponse{}, fmt.Errorf("cant scan snapshot name into string: %w", err) } return SnapshotCreationResponse{ExportedSnapshotName: snapshotName.String}, nil @@ -96,51 +93,37 @@ func (s *Snapshotter) setTransactionSnapshotName(snapshotName string) { s.snapshotName = snapshotName } -func (s *Snapshotter) prepare() error { +func (s *Snapshotter) prepare(ctx context.Context) error { if s.snapshotName == "" { return errors.New("snapshot name is not set") } - - if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { - return err + if s.readerTxn != nil { + return errors.New("reader txn already open") } - + // Use a background context because we explicitly want the Tx to be long lived, we explicitly close it in the close method + tx, err := s.pool.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true, Isolation: sql.LevelRepeatableRead}) + if err != nil { + return fmt.Errorf("unable to start reader txn: %w", err) + } + s.readerTxn = tx sq, err := sanitize.SQLQuery("SET TRANSACTION SNAPSHOT $1;", s.snapshotName) if err != nil { return err } - - if _, err := s.pgConnection.Exec(sq); err != nil { + if _, err := tx.ExecContext(ctx, sq); err != nil { return err } - return nil } -func (s *Snapshotter) findAvgRowSize(ctx context.Context, table TableFQN) (sql.NullInt64, error) { - var ( - avgRowSize sql.NullInt64 - rows *sql.Rows - err error - ) - - // table is validated to be correct pg identifier, so we can use it directly - if rows, err = s.pgConnection.QueryContext(ctx, fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { - return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) +func (s *Snapshotter) findAvgRowSize(ctx context.Context, table TableFQN) (avgRowSize sql.NullInt64, err error) { + row := s.readerTxn.QueryRowContext(ctx, fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)) + if row.Err() != nil { + return avgRowSize, fmt.Errorf("cannot get avg row size due to query failure: %w", err) } - - if rows.Err() != nil { - return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", rows.Err()) + if err = row.Scan(&avgRowSize); err != nil { + return avgRowSize, fmt.Errorf("cannot get avg row size: %w", err) } - - if rows.Next() { - if err = rows.Scan(&avgRowSize); err != nil { - return avgRowSize, fmt.Errorf("can get avg row size: %w", err) - } - } else { - return avgRowSize, errors.New("can get avg row size; 0 rows returned") - } - return avgRowSize, nil } @@ -280,7 +263,6 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz } func (s *Snapshotter) querySnapshotData(ctx context.Context, table TableFQN, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { - s.logger.Debugf("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) if lastSeenPk == nil { @@ -289,19 +271,16 @@ func (s *Snapshotter) querySnapshotData(ctx context.Context, table TableFQN, las if err != nil { return nil, err } - - return s.pgConnection.QueryContext(ctx, sq) + return s.readerTxn.QueryContext(ctx, sq) } var ( placeholders []string lastSeenPksValues []any - i = 1 ) - for _, col := range pkColumns { - placeholders = append(placeholders, fmt.Sprintf("$%d", i)) - i++ + for i, col := range pkColumns { + placeholders = append(placeholders, fmt.Sprintf("$%d", i+1)) lastSeenPksValues = append(lastSeenPksValues, lastSeenPk[col]) } @@ -314,27 +293,38 @@ func (s *Snapshotter) querySnapshotData(ctx context.Context, table TableFQN, las return nil, err } - return s.pgConnection.QueryContext(ctx, sq) + return s.readerTxn.QueryContext(ctx, sq) } func (s *Snapshotter) releaseSnapshot() error { - if s.version < 14 && s.snapshotCreateConnection != nil { - if _, err := s.snapshotCreateConnection.Exec("COMMIT;"); err != nil { + if s.version < 14 && s.snapshotTxn != nil { + if err := s.snapshotTxn.Commit(); err != nil { return err } + s.snapshotTxn = nil } - - _, err := s.pgConnection.Exec("COMMIT;") - return err + if err := s.readerTxn.Commit(); err != nil { + return err + } + s.readerTxn = nil + return nil } func (s *Snapshotter) closeConn() error { - if s.pgConnection != nil { - return s.pgConnection.Close() + if s.readerTxn != nil { + if err := s.readerTxn.Rollback(); err != nil { + return err + } + s.readerTxn = nil } - - if s.snapshotCreateConnection != nil { - return s.snapshotCreateConnection.Close() + if s.snapshotTxn != nil { + if err := s.snapshotTxn.Rollback(); err != nil { + return err + } + s.snapshotTxn = nil + } + if err := s.pool.Close(); err != nil { + return err } return nil