Skip to content

Commit

Permalink
Merge pull request #793 from companieshouse/sec-112-enable-add-csrf-m…
Browse files Browse the repository at this point in the history
…iddleware

Sec 112 enable add csrf middleware
  • Loading branch information
mattch1 authored Oct 21, 2024
2 parents dacf64e + 5ca198e commit 4b5d21a
Show file tree
Hide file tree
Showing 13 changed files with 102 additions and 3 deletions.
17 changes: 17 additions & 0 deletions src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import { isPscQueryParameterValidationMiddleware } from "./middleware/is.psc.val
import { transactionIdValidationMiddleware } from "./middleware/transaction.id.validation.middleware";
import { submissionIdValidationMiddleware } from "./middleware/submission.id.validation.middleware";
import { commonTemplateVariablesMiddleware } from "./middleware/common.variables.middleware";
import { CsrfProtectionMiddleware } from "@companieshouse/web-security-node";
import { SessionStore } from "@companieshouse/node-session-handler";
import { CACHE_SERVER, COOKIE_NAME } from "./utils/properties";
import Redis from 'ioredis';

const app = express();
app.disable("x-powered-by");
Expand All @@ -24,6 +28,7 @@ const nunjucksEnv = nunjucks.configure([
"views",
"node_modules/govuk-frontend/",
"node_modules/govuk-frontend/components/",
"node_modules/@companieshouse"
], {
autoescape: true,
express: app,
Expand All @@ -44,6 +49,7 @@ app.set("view engine", "html");
// apply middleware
app.use(cookieParser());
app.use(serviceAvailabilityMiddleware);

// validation middleware for url and query params - comapny number covered by companyAuthenticationMiddleware
// These need to run before companyAuthenticationMiddleware as that can log out full url
// if auth value is invalid and url has invalid data in it
Expand All @@ -53,10 +59,21 @@ app.use(`*${urls.ACTIVE_SUBMISSION_BASE}`, transactionIdValidationMiddleware);
app.use(`*${urls.ACTIVE_SUBMISSION_BASE}`, submissionIdValidationMiddleware);

app.use(`${urls.CONFIRMATION_STATEMENT}*`, sessionMiddleware);

const userAuthRegex = new RegExp("^" + urls.CONFIRMATION_STATEMENT + "/.+");
app.use(userAuthRegex, authenticationMiddleware);
app.use(`${urls.CONFIRMATION_STATEMENT}${urls.COMPANY_AUTH_PROTECTED_BASE}`, companyAuthenticationMiddleware);

//csrf middleware
const sessionStore = new SessionStore(new Redis(`redis://${CACHE_SERVER}`));

const csrfProtectionMiddleware = CsrfProtectionMiddleware({
sessionStore,
enabled: false,
sessionCookieName: COOKIE_NAME
});
app.use(csrfProtectionMiddleware);

app.use(commonTemplateVariablesMiddleware);
// apply our default router to /confirmation-statement
app.use(urls.CONFIRMATION_STATEMENT, router);
Expand Down
18 changes: 17 additions & 1 deletion src/controllers/error.controller.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { NextFunction, Request, Response } from "express";
import { logger } from "../utils/logger";
import { Templates } from "../types/template.paths";
import { CsrfError } from "@companieshouse/web-security-node";

const pageNotFound = (req: Request, res: Response) => {
return res.status(404).render(Templates.ERROR_404, { templateName: Templates.ERROR_404 });
Expand All @@ -16,4 +17,19 @@ const errorHandler = (err: Error, req: Request, res: Response, _next: NextFuncti
res.status(500).render(Templates.SERVICE_OFFLINE_MID_JOURNEY, { templateName: Templates.SERVICE_OFFLINE_MID_JOURNEY });
};

export default [pageNotFound, errorHandler];
/**
* This handler catches any CSRF errors thrown within the application.
* If it is not a CSRF, the error is passed to the next error handler.
* If it is a CSRF error, it responds with a 403 forbidden status and renders the CSRF error.
*/
const csrfErrorHandler = (err: CsrfError | Error, req: Request, res: Response, next: NextFunction) => {
if (!(err instanceof CsrfError)) {
return next(err);
}
return res.status(403).render(Templates.CSRF_ERROR, {
templateName: Templates.CSRF_ERROR,
csrfErrors: true
});
};

export default [pageNotFound, csrfErrorHandler, errorHandler];
3 changes: 2 additions & 1 deletion src/types/template.paths.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ export enum Templates {
CONFIRM_EMAIL_ADDRESS = "tasks/confirm-email-address",
REGISTERED_OFFICE_ADDRESS = "tasks/registered-office-address",
REGISTER_LOCATIONS = "tasks/register-locations",
REVIEW = "review"
REVIEW = "review",
CSRF_ERROR = "csrf-error"
}
36 changes: 36 additions & 0 deletions test/controllers/error.controller.csrf.unit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
jest.mock("../../src/controllers/confirm.company.controller");

import mocks from "../mocks/all.middleware.mock";
import request from "supertest";
import app from "../../src/app";
import * as pageUrls from "../../src/types/page.urls";
import { CsrfError } from "@companieshouse/web-security-node";
import * as confirmCompanyController from "../../src/controllers/confirm.company.controller";

const mockGetConfirmCompany = confirmCompanyController.get as jest.Mock;

const CSRF_TOKEN_ERROR = "CSRF token mismatch";
const CSRF_ERROR_PAGE_TEXT = "We have not been able to save the information you submitted on the previous screen.";
const CSRF_ERROR_PAGE_HEADING = "Sorry, something went wrong";

describe("ERROR controller", () => {

beforeEach(() => {
jest.clearAllMocks();
// clearing the mocks will initialise them for first use as well as between tests
mocks.mockAuthenticationMiddleware.mockClear();
mocks.mockServiceAvailabilityMiddleware.mockClear();
mocks.mockSessionMiddleware.mockClear();
});

describe("CSRF error page tests", () => {

test("Should render the CSRF error page", async () => {
mockGetConfirmCompany.mockImplementationOnce(() => { throw new CsrfError(CSRF_TOKEN_ERROR); });
const response = await request(app).get(pageUrls.CONFIRM_COMPANY_PATH);
expect(response.status).toEqual(403);
expect(response.text).toContain(CSRF_ERROR_PAGE_HEADING);
expect(response.text).toContain(CSRF_ERROR_PAGE_TEXT);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ describe("Confirm Email Address controller tests", () => {
mocks.mockAuthenticationMiddleware.mockClear();
mocks.mockServiceAvailabilityMiddleware.mockClear();
mocks.mockSessionMiddleware.mockClear();
mocks.mockCsrfMiddleware.mockClear();
});

it("Should navigate to the Confirm Email Address page", async () => {
Expand Down
2 changes: 2 additions & 0 deletions test/middleware/company.authentication.middleware.unit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ jest.mock("@companieshouse/web-security-node");
jest.mock("../../src/utils/logger");
jest.mock("../../src/validators/company.number.validator");

import mockCsrfProtectionMiddleware from "../mocks/csrf.middleware.mock";
import mockSessionMiddleware from "../mocks/session.middleware.mock";
import mockServiceAvailabilityMiddleware from "../mocks/service.availability.middleware.mock";
import mockAuthenticationMiddleware from "../mocks/authentication.middleware.mock";
Expand Down Expand Up @@ -48,6 +49,7 @@ describe("company authentication middleware tests", () => {
mockTransactionIdValidationMiddleware.mockClear();
mockSubmissionIdValidationMiddleware.mockClear();
mockLoggerErrorRequest.mockClear();
mockCsrfProtectionMiddleware.mockClear();
});

it("should call CH authentication library when company pattern in url", async () => {
Expand Down
2 changes: 2 additions & 0 deletions test/middleware/service.availability.middleware.unit.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
jest.mock("ioredis");
jest.mock("../../src/utils/feature.flag");

import mockCsrfProtectionMiddleware from "../mocks/csrf.middleware.mock";
import request from "supertest";
import app from "../../src/app";
import { isActiveFeature } from "../../src/utils/feature.flag";
Expand All @@ -11,6 +12,7 @@ describe("service availability middleware tests", () => {

beforeEach(() => {
jest.clearAllMocks();
mockCsrfProtectionMiddleware.mockClear();
});

it("should return service offline page", async () => {
Expand Down
2 changes: 2 additions & 0 deletions test/middleware/submission.id.validation.middleware.unit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ jest.mock("../../src/middleware/transaction.id.validation.middleware");
jest.mock("../../src/services/company.profile.service");
jest.mock("../../src/utils/logger");

import mockCsrfProtectionMiddleware from "../mocks/csrf.middleware.mock";
import mockServiceAvailabilityMiddleware from "../mocks/service.availability.middleware.mock";
import mockAuthenticationMiddleware from "../mocks//authentication.middleware.mock";
import mockSessionMiddleware from "../mocks/session.middleware.mock";
Expand Down Expand Up @@ -48,6 +49,7 @@ describe("Submission ID validation middleware tests", () => {
mockIsUrlIdValid.mockClear();
mockTransactionIdValidationMiddleware.mockClear();
mockLoggerErrorRequest.mockClear();
mockCsrfProtectionMiddleware.mockClear();
});

it("Should stop invalid submission id", async () => {
Expand Down
2 changes: 2 additions & 0 deletions test/middleware/transaction.id.validation.middleware.unit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ jest.mock("../../src/middleware/submission.id.validation.middleware");
jest.mock("../../src/services/company.profile.service");
jest.mock("../../src/utils/logger");

import mockCsrfProtectionMiddleware from "../mocks/csrf.middleware.mock";
import mockServiceAvailabilityMiddleware from "../mocks/service.availability.middleware.mock";
import mockAuthenticationMiddleware from "../mocks/authentication.middleware.mock";
import mockSessionMiddleware from "../mocks/session.middleware.mock";
Expand Down Expand Up @@ -47,6 +48,7 @@ describe("Transaction ID validation middleware tests", () => {
mockIsUrlIdValid.mockClear();
mockSubmissionIdValidationMiddleware.mockClear();
mockLoggerErrorRequest.mockClear();
mockCsrfProtectionMiddleware.mockClear();
});

it("Should stop invalid transaction id", async () => {
Expand Down
4 changes: 3 additions & 1 deletion test/mocks/all.middleware.mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import mockSubmissionIdValidationMiddleware from "./submission.id.validation.mid
import mockTransactionIdValidationMiddleware from "./transaction.id.validation.middleware.mock";
import mockIsPscQueryParameterValidationMiddleware from "./is.psc.validation.middleware.mock";
import mockCompanyNumberQueryParameterValidationMiddleware from "./company.number.validation.middleware.mock";
import mockCsrfMiddleware from "./csrf.middleware.mock";

export default {
mockServiceAvailabilityMiddleware,
Expand All @@ -15,5 +16,6 @@ export default {
mockSubmissionIdValidationMiddleware,
mockTransactionIdValidationMiddleware,
mockIsPscQueryParameterValidationMiddleware,
mockCompanyNumberQueryParameterValidationMiddleware
mockCompanyNumberQueryParameterValidationMiddleware,
mockCsrfMiddleware
};
9 changes: 9 additions & 0 deletions test/mocks/csrf.middleware.mock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { NextFunction, Request, Response } from "express";
import { CsrfProtectionMiddleware } from "@companieshouse/web-security-node";

jest.mock("@companieshouse/web-security-node");

const mockCsrfProtectionMiddleware = CsrfProtectionMiddleware as jest.Mock;
mockCsrfProtectionMiddleware.mockImplementation((_opts) => (req: Request, res: Response, next: NextFunction) => next());

export default mockCsrfProtectionMiddleware;
3 changes: 3 additions & 0 deletions test/routes/routes.spec.unit.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
jest.mock("ioredis");

import mockCsrfProtectionMiddleware from "../mocks/csrf.middleware.mock";
import request from "supertest";
import app from "../../src/app";

describe("Basic URL Tests", () => {

it("should find the accessibility statement page", async () => {
mockCsrfProtectionMiddleware.mockClear();

const response = await request(app)
.get("/confirmation-statement/accessibility-statement");

Expand Down
6 changes: 6 additions & 0 deletions views/csrf-error.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{% from "web-security-node/components/csrf-error/macro.njk" import csrfError %}
{% extends "layout.html" %}

{% block content %}
{{ csrfError({}) }}
{% endblock %}

0 comments on commit 4b5d21a

Please sign in to comment.