Skip to content

Commit

Permalink
Merge pull request #1374 from znsio/discriminator_based_example_gen
Browse files Browse the repository at this point in the history
Support discriminator based example generation
  • Loading branch information
joelrosario authored Oct 23, 2024
2 parents 1e9c236 + 8c1d078 commit a9c1a1e
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 152 deletions.
32 changes: 32 additions & 0 deletions core/src/main/kotlin/io/specmatic/core/Feature.kt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,38 @@ data class Feature(
}
}

fun generateRequestResponses(scenario: Scenario): List<GeneratedRequestResponse> {
try {
val requests = scenario.generateHttpRequestV2()
val responses = scenario.generateHttpResponseV2(serverState)

val generatedRequestResponses = if(requests.size > responses.size) {
requests.map { (discriminator, request) ->
val response = if(responses.containsKey(discriminator)) responses.getValue(discriminator)
else responses.values.first()
GeneratedRequestResponse(request, response, discriminator)
}
} else {
responses.map { (discriminator, response) ->
val request = if(requests.containsKey(discriminator)) requests.getValue(discriminator)
else requests.values.first()
GeneratedRequestResponse(request, response, discriminator)
}
}

return generatedRequestResponses
} finally {
serverState = emptyMap()
}
}

// Better name
data class GeneratedRequestResponse(
val request: HttpRequest,
val response: HttpResponse,
val requestKind: String
)

fun stubResponse(
httpRequest: HttpRequest,
mismatchMessages: MismatchMessages = DefaultMismatchMessages
Expand Down
135 changes: 87 additions & 48 deletions core/src/main/kotlin/io/specmatic/core/HttpRequestPattern.kt
Original file line number Diff line number Diff line change
Expand Up @@ -439,70 +439,109 @@ data class HttpRequestPattern(
}

fun generate(resolver: Resolver): HttpRequest {
var newRequest = HttpRequest()

return attempt(breadCrumb = "REQUEST") {
if (method == null) {
throw missingParam("HTTP method")
}
if (httpPathPattern == null) {
throw missingParam("URL path")
}
newRequest = newRequest.updateMethod(method)
attempt(breadCrumb = "URL") {
newRequest = newRequest.updatePath(httpPathPattern.generate(resolver))
val queryParams = httpQueryParamPattern.generate(resolver)
for (queryParam in queryParams) {
newRequest = newRequest.updateQueryParam(queryParam.first, queryParam.second)
}
}
val headers = headersPattern.generate(resolver)
HttpRequest()
.updateMethod(method)
.generateAndUpdateURL(resolver)
.generateAndUpdateBody(resolver, body)
.generateAndUpdateHeaders(resolver)
.generateAndUpdateFormFieldsValues(resolver)
.generateAndUpdateSecuritySchemes(resolver)
.generateAndUpdateMultiPartData(resolver)
}
}

val body = body
attempt(breadCrumb = "BODY") {
resolver.withCyclePrevention(body) {cyclePreventedResolver ->
body.generate(cyclePreventedResolver).let { value ->
newRequest = newRequest.updateBody(value)
}
}
fun generateV2(resolver: Resolver): Map<String, HttpRequest> {
return attempt(breadCrumb = "REQUEST") {
if (method == null) {
throw missingParam("HTTP method")
}
val baseRequest = HttpRequest()
.updateMethod(method)
.generateAndUpdateURL(resolver)
.generateAndUpdateHeaders(resolver)
.generateAndUpdateFormFieldsValues(resolver)
.generateAndUpdateSecuritySchemes(resolver)
.generateAndUpdateMultiPartData(resolver)

generateDiscriminatorBasedValues(resolver, body).map { (discriminatorKey, generatedBody) ->
discriminatorKey to baseRequest.updateBody(generatedBody)
}.toMap()
}
}

newRequest = newRequest.copy(headers = headers)
private fun HttpRequest.generateAndUpdatePath(resolver: Resolver): HttpRequest {
if (httpPathPattern == null) {
throw missingParam("URL path")
}
return this.updatePath(httpPathPattern.generate(resolver))
}

val formFieldsValue = attempt(breadCrumb = "FORM FIELDS") {
formFieldsPattern.mapValues { (key, pattern) ->
attempt(breadCrumb = key) {
resolver.withCyclePrevention(pattern) { cyclePreventedResolver ->
cyclePreventedResolver.generate(key, pattern)
}.toString()
}
private fun HttpRequest.generateAndUpdateQueryParam(resolver: Resolver): HttpRequest {
val queryParams = httpQueryParamPattern.generate(resolver)
return queryParams.fold(this) { request, queryParam ->
request.updateQueryParam(queryParam.first, queryParam.second)
}
}

private fun HttpRequest.generateAndUpdateURL(resolver: Resolver): HttpRequest {
return attempt(breadCrumb = "URL") {
this.generateAndUpdatePath(resolver)
.generateAndUpdateQueryParam(resolver)
}
}

private fun HttpRequest.generateAndUpdateBody(resolver: Resolver, body: Pattern): HttpRequest {
return attempt(breadCrumb = "BODY") {
resolver.withCyclePrevention(body) {cyclePreventedResolver ->
body.generate(cyclePreventedResolver).let { value ->
this.updateBody(value)
}
}
newRequest = when (formFieldsValue.size) {
0 -> newRequest
else -> newRequest.copy(
formFields = formFieldsValue,
headers = newRequest.headers.plus(CONTENT_TYPE to "application/x-www-form-urlencoded")
)
}
}
}

newRequest = securitySchemes.fold(newRequest) { request, securityScheme ->
securityScheme.addTo(request, resolver)
}
private fun HttpRequest.generateAndUpdateHeaders(resolver: Resolver): HttpRequest {
return this.copy(headers = headersPattern.generate(resolver))
}

val multipartData = attempt(breadCrumb = "MULTIPART DATA") {
multiPartFormDataPattern.mapIndexed { index, multiPartFormDataPattern ->
attempt(breadCrumb = "[$index]") { multiPartFormDataPattern.generate(resolver) }
private fun HttpRequest.generateAndUpdateFormFieldsValues(resolver: Resolver): HttpRequest {
val formFieldsValue = attempt(breadCrumb = "FORM FIELDS") {
formFieldsPattern.mapValues { (key, pattern) ->
attempt(breadCrumb = key) {
resolver.withCyclePrevention(pattern) { cyclePreventedResolver ->
cyclePreventedResolver.generate(key, pattern)
}.toString()
}
}
when (multipartData.size) {
0 -> newRequest
else -> newRequest.copy(
multiPartFormData = multipartData,
headers = newRequest.headers.plus(CONTENT_TYPE to "multipart/form-data")
)
}
if(formFieldsValue.isEmpty()) return this
return this.copy(
formFields = formFieldsValue,
headers = this.headers.plus(CONTENT_TYPE to "application/x-www-form-urlencoded")
)
}

private fun HttpRequest.generateAndUpdateSecuritySchemes(resolver: Resolver): HttpRequest {
return securitySchemes.fold(this) { request, securityScheme ->
securityScheme.addTo(request, resolver)
}
}

private fun HttpRequest.generateAndUpdateMultiPartData(resolver: Resolver): HttpRequest {
val multipartData = attempt(breadCrumb = "MULTIPART DATA") {
multiPartFormDataPattern.mapIndexed { index, multiPartFormDataPattern ->
attempt(breadCrumb = "[$index]") { multiPartFormDataPattern.generate(resolver) }
}
}
if(multipartData.isEmpty()) return this
return this.copy(
multiPartFormData = multipartData,
headers = this.headers.plus(CONTENT_TYPE to "multipart/form-data")
)
}

fun newBasedOn(row: Row, initialResolver: Resolver, status: Int = 0): Sequence<ReturnValue<HttpRequestPattern>> {
Expand Down
35 changes: 35 additions & 0 deletions core/src/main/kotlin/io/specmatic/core/HttpResponsePattern.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.specmatic.core

import io.specmatic.core.pattern.*
import io.specmatic.core.value.JSONArrayValue
import io.specmatic.core.value.ListValue
import io.specmatic.core.value.StringValue
import io.specmatic.core.value.Value
import io.specmatic.stub.softCastValueToXML
Expand Down Expand Up @@ -28,6 +30,20 @@ data class HttpResponsePattern(
}
}

fun generateResponseV2(resolver: Resolver): Map<String, HttpResponse> {
return attempt(breadCrumb = "RESPONSE") {
generateDiscriminatorBasedValues(resolver, body).map { (discriminatorKey, value) ->
val generatedBody = softCastValueToXML(value)
val headers = headersPattern.generate(resolver).plus(SPECMATIC_RESULT_HEADER to "success").let { headers ->
if ((headers.containsKey("Content-Type").not() && generatedBody.httpContentType.isBlank().not()))
headers.plus("Content-Type" to generatedBody.httpContentType)
else headers
}
discriminatorKey to HttpResponse(status, headers, generatedBody)
}.toMap()
}
}

fun generateResponseWithAll(resolver: Resolver): HttpResponse {
return attempt(breadCrumb = "RESPONSE") {
val value = softCastValueToXML(body.generateWithAll(resolver))
Expand Down Expand Up @@ -181,6 +197,25 @@ data class HttpResponsePattern(
}
}

fun generateDiscriminatorBasedValues(resolver: Resolver, pattern: Pattern): Map<String, Value> {
return resolver.withCyclePrevention(pattern) { updatedResolver ->
val resolvedPattern = resolvedHop(pattern, updatedResolver)

if(resolvedPattern is ListPattern) {
val listValuePattern = resolvedHop(resolvedPattern.pattern, updatedResolver)
if(listValuePattern is AnyPattern && listValuePattern.isDiscriminatorPresent()) {
val values = listValuePattern.generateForEveryDiscriminatorValue(updatedResolver)
return@withCyclePrevention values.mapValues { JSONArrayValue(listOf(it.value)) }
}
}

if(resolvedPattern !is AnyPattern || resolvedPattern.isDiscriminatorPresent().not()) {
return@withCyclePrevention mapOf("" to resolvedPattern.generate(updatedResolver))
}
resolvedPattern.generateForEveryDiscriminatorValue(updatedResolver)
}
}

private val valueMismatchMessages = object : MismatchMessages {
override fun mismatchMessage(expected: String, actual: String): String {
return "Value mismatch: Expected $expected, got value $actual"
Expand Down
16 changes: 15 additions & 1 deletion core/src/main/kotlin/io/specmatic/core/Scenario.kt
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ data class Scenario(
httpResponsePattern.generateResponse(resolver.copy(factStore = CheckFacts(facts), context = requestContext))
}

fun generateHttpResponseV2(actualFacts: Map<String, Value>, requestContext: Context = NoContext): Map<String, HttpResponse> =
scenarioBreadCrumb(this) {
val facts = combineFacts(expectedFacts, actualFacts, resolver)

httpResponsePattern.generateResponseV2(resolver.copy(factStore = CheckFacts(facts), context = requestContext))
}

private fun combineFacts(
expected: Map<String, Value>,
actual: Map<String, Value>,
Expand Down Expand Up @@ -219,7 +226,14 @@ data class Scenario(
}

fun generateHttpRequest(flagsBased: FlagsBased = DefaultStrategies): HttpRequest =
scenarioBreadCrumb(this) { httpRequestPattern.generate(flagsBased.update(resolver.copy(factStore = CheckFacts(expectedFacts)))) }
scenarioBreadCrumb(this) {
httpRequestPattern.generate(flagsBased.update(resolver.copy(factStore = CheckFacts(expectedFacts))))
}

fun generateHttpRequestV2(flagsBased: FlagsBased = DefaultStrategies): Map<String, HttpRequest> =
scenarioBreadCrumb(this) {
httpRequestPattern.generateV2(flagsBased.update(resolver.copy(factStore = CheckFacts(expectedFacts))))
}

fun matches(httpRequest: HttpRequest, httpResponse: HttpResponse, mismatchMessages: MismatchMessages = DefaultMismatchMessages, unexpectedKeyCheck: UnexpectedKeyCheck? = null): Result {
val resolver = updatedResolver(mismatchMessages, unexpectedKeyCheck).copy(context = RequestContext(httpRequest))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ class ExamplesInteractiveServer(

return getExistingExampleFiles(scenario, examples).map {
ExamplePathInfo(it.first.absolutePath, false)
}.plus(generateExampleFile(contractFile, feature, scenario))
}.plus(generateExampleFiles(contractFile, feature, scenario))
}

data class ExamplePathInfo(val path: String, val created: Boolean)
Expand All @@ -442,6 +442,35 @@ class ExamplesInteractiveServer(
return ExamplePathInfo(file.absolutePath, true)
}


private fun generateExampleFiles(
contractFile: File,
feature: Feature,
scenario: Scenario,
): List<ExamplePathInfo> {
val examplesDir = getExamplesDirPath(contractFile)
if(!examplesDir.exists()) examplesDir.mkdirs()

val generatedRequestResponses = feature.generateRequestResponses(scenario).map {
it.copy(response = it.response.cleanup())
}

return generatedRequestResponses.map { (request, response, kind) ->
val scenarioStub = ScenarioStub(request, response)
val stubJSON = scenarioStub.toJSON()
val uniqueNameForApiOperation = uniqueNameForApiOperation(
scenarioStub.request,
"",
scenarioStub.response.status
) + if (kind.isNotEmpty()) "_$kind" else ""

val file = examplesDir.resolve("${uniqueNameForApiOperation}_${exampleFileNamePostFixCounter.incrementAndGet()}.json")
println("Writing to file: ${file.relativeTo(contractFile.canonicalFile.parentFile).path}")
file.writeText(stubJSON.toStringLiteral())
ExamplePathInfo(file.absolutePath, true)
}
}

fun validateSingleExample(contractFile: File, exampleFile: File): Result {
val feature = parseContractFileToFeature(contractFile)
return validateSingleExample(feature, exampleFile)
Expand Down
Loading

0 comments on commit a9c1a1e

Please sign in to comment.