Skip to content

Commit

Permalink
Enforce maximum payload size
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrewster committed Oct 24, 2024
1 parent 577c1d7 commit 7df315e
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 15 deletions.
127 changes: 127 additions & 0 deletions __tests__/max-payload-size.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import {
afterEach,
assert,
beforeEach,
describe,
expect,
test,
vitest,
} from 'vitest';
import { createMockTransportNetwork } from '../testUtil/fixtures/mockTransport';
import {
Err,
Ok,
Procedure,
ServiceSchema,
createClient,
createServer,
} from '../router';
import { MAX_PAYLOAD_SIZE_EXCEEDED } from '../router/errors';
import { Type } from '@sinclair/typebox';
import { readNextResult } from '../testUtil';
import { MaxPayloadSizeExceeded } from '../transport/sessionStateMachine/common';

describe('client exceeded max payload size', () => {
let mockTransportNetwork: ReturnType<typeof createMockTransportNetwork>;

beforeEach(async () => {
mockTransportNetwork = createMockTransportNetwork({
client: { maxPayloadSizeBytes: 1024 },
});
});

afterEach(async () => {
await mockTransportNetwork.cleanup();
});

test('rpc init exceeds max payload size', async () => {
const mockHandler = vitest.fn();
const services = {
service: ServiceSchema.define({
echo: Procedure.rpc({
requestInit: Type.String(),
responseData: Type.String(),
handler: mockHandler,
}),
}),
};
createServer(mockTransportNetwork.getServerTransport(), services);
const client = createClient<typeof services>(
mockTransportNetwork.getClientTransport('client'),
'SERVER',
);

const result = await client.service.echo.rpc('0'.repeat(1025));
expect(result).toStrictEqual({
ok: false,
payload: {
code: MAX_PAYLOAD_SIZE_EXCEEDED,
message: 'payload exceeded maximum payload size size=1241 max=1024',
},
});
expect(mockHandler).not.toHaveBeenCalled();
});

test('stream message exceeds max payload size', async () => {
let handlerCanceled: Promise<null> | undefined;
const services = {
service: ServiceSchema.define({
echo: Procedure.stream({
requestInit: Type.String(),
requestData: Type.String(),
responseData: Type.String(),
responseError: Type.Object({
code: Type.Literal('ERROR'),
message: Type.String(),
}),
handler: async ({ ctx, reqInit, reqReadable, resWritable }) => {
handlerCanceled = new Promise((resolve) => {
ctx.signal.onabort = () => resolve(null);
});

resWritable.write(Ok(reqInit));
for await (const msg of reqReadable) {
if (msg.ok) {
resWritable.write(Ok(msg.payload));
} else {
resWritable.write(
Err({
code: 'ERROR',
message: 'error reading from client',
}),
);
break;
}
}
},
}),
}),
};
createServer(mockTransportNetwork.getServerTransport(), services);
const transport = mockTransportNetwork.getClientTransport('client');
const client = createClient<typeof services>(transport, 'SERVER');

const stream = client.service.echo.stream('start');
let result = await readNextResult(stream.resReadable);
expect(result).toStrictEqual({ ok: true, payload: 'start' });

let error;
try {
stream.reqWritable.write('0'.repeat(1025));
} catch (e) {
error = e;
}
expect(error).toBeInstanceOf(MaxPayloadSizeExceeded);

result = await readNextResult(stream.resReadable);
expect(result).toStrictEqual({
ok: false,
payload: {
code: MAX_PAYLOAD_SIZE_EXCEEDED,
message: 'payload exceeded maximum payload size size=1148 max=1024',
},
});
assert(handlerCanceled);
await handlerCanceled;
});
});
80 changes: 65 additions & 15 deletions router/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ import {
CANCEL_CODE,
ReaderErrorSchema,
UNEXPECTED_DISCONNECT_CODE,
MAX_PAYLOAD_SIZE_EXCEEDED,
} from './errors';
import { MaxPayloadSizeExceeded } from '../transport/sessionStateMachine/common';

const ReaderErrResultSchema = ErrResultSchema(ReaderErrorSchema);

Expand Down Expand Up @@ -297,11 +299,42 @@ function handleProc(
let cleanClose = true;
const reqWritable = new WritableImpl<Static<PayloadType>>({
writeCb: (rawIn) => {
sessionScopedSend({
streamId,
payload: rawIn,
controlFlags: 0,
});
try {
sessionScopedSend({
streamId,
payload: rawIn,
controlFlags: 0,
});
} catch (e) {
if (!(e instanceof MaxPayloadSizeExceeded)) {
throw e;
}

cleanClose = false;
if (!resReadable.isClosed()) {
resReadable._pushValue(
Err({
code: MAX_PAYLOAD_SIZE_EXCEEDED,
message: e.message,
}),
);
closeReadable();
}

reqWritable.close();
// TODO: Is this the right error to send to the server?
sessionScopedSend(
cancelMessage(
streamId,
Err({
code: CANCEL_CODE,
message: 'cancelled by client',
}),
),
);

throw e;
}
},
// close callback
closeCb: () => {
Expand Down Expand Up @@ -480,16 +513,33 @@ function handleProc(
transport.addEventListener('message', onMessage);
transport.addEventListener('sessionStatus', onSessionStatus);

sessionScopedSend({
streamId,
serviceName,
procedureName,
tracing: getPropagationContext(ctx),
payload: init,
controlFlags: procClosesWithInit
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
: ControlFlags.StreamOpenBit,
});
try {
sessionScopedSend({
streamId,
serviceName,
procedureName,
tracing: getPropagationContext(ctx),
payload: init,
controlFlags: procClosesWithInit
? ControlFlags.StreamOpenBit | ControlFlags.StreamClosedBit
: ControlFlags.StreamOpenBit,
});
} catch (e) {
if (!(e instanceof MaxPayloadSizeExceeded)) {
throw e;
}

cleanClose = false;
resReadable._pushValue(
Err({
code: MAX_PAYLOAD_SIZE_EXCEEDED,
message: e.message,
}),
);
closeReadable();

reqWritable.close();
}

if (procClosesWithInit) {
reqWritable.close();
Expand Down
8 changes: 8 additions & 0 deletions router/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ export const INVALID_REQUEST_CODE = 'INVALID_REQUEST';
* {@link CANCEL_CODE} is the code used when either server or client cancels the stream.
*/
export const CANCEL_CODE = 'CANCEL';
/**
* {@link MAX_PAYLOAD_SIZE_EXCEEDED} is the code used when a request's payload exceeds the maximum allowed size.
*/
export const MAX_PAYLOAD_SIZE_EXCEEDED = 'MAX_PAYLOAD_SIZE_EXCEEDED';

type TLiteralString = TLiteral<string>;

Expand Down Expand Up @@ -72,6 +76,10 @@ export const ReaderErrorSchema = Type.Union([
code: Type.Literal(CANCEL_CODE),
message: Type.String(),
}),
Type.Object({
code: Type.Literal(MAX_PAYLOAD_SIZE_EXCEEDED),
message: Type.String(),
}),
]) satisfies ProcedureErrorSchemaType;

/**
Expand Down
1 change: 1 addition & 0 deletions transport/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export const defaultTransportOptions: TransportOptions = {
connectionTimeoutMs: 2_000,
handshakeTimeoutMs: 1_000,
enableTransparentSessionReconnects: true,
maxPayloadSizeBytes: 4 * 1024 * 1024,
codec: NaiveJsonCodec,
};

Expand Down
17 changes: 17 additions & 0 deletions transport/sessionStateMachine/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ export interface SessionOptions {
* The codec to use for encoding/decoding messages over the wire
*/
codec: Codec;
/**
* The maximum payload size that is allowed to be sent or received.
*/
maxPayloadSizeBytes: number;
}

// all session states have a from and options
Expand Down Expand Up @@ -209,6 +213,12 @@ export interface IdentifiedSessionProps extends CommonSessionProps {
protocolVersion: ProtocolVersion;
}

export class MaxPayloadSizeExceeded extends Error {
constructor(size: number, max: number) {
super(`payload exceeded maximum payload size size=${size} max=${max}`);
}
}

export abstract class IdentifiedSession extends CommonSession {
readonly id: SessionId;
readonly telemetry: TelemetryInfo;
Expand Down Expand Up @@ -276,6 +286,13 @@ export abstract class IdentifiedSession extends CommonSession {
data: this.options.codec.toBuffer(msg),
};

if (encodedMsg.data.byteLength > this.options.maxPayloadSizeBytes) {
throw new MaxPayloadSizeExceeded(
encodedMsg.data.byteLength,
this.options.maxPayloadSizeBytes,
);
}

this.seq++;

return encodedMsg;
Expand Down

0 comments on commit 7df315e

Please sign in to comment.