Skip to content

Commit

Permalink
Add GenSealedEnum.useIsChecksForSealedObjectComparison (fixes #129)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyay10 committed Apr 10, 2023
1 parent 12257ab commit d794e89
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.livefront.sealedenum.TreeTraversalOrder
public data class GenSealedEnumHolder(
val traversalOrder: TreeTraversalOrder,
val generateEnum: Boolean,
val useIsChecksForSealedObjectComparison: Boolean,
) {

public companion object {
Expand All @@ -34,9 +35,14 @@ public data class GenSealedEnumHolder(
it.name?.asString() == "generateEnum"
}?.value.toString().toBoolean()

val useIsChecksForSealedObjectComparison = ksAnnotation.arguments.find {
it.name?.asString() == "useIsChecksForSealedObjectComparison"
}?.value.toString().toBoolean()

return GenSealedEnumHolder(
traversalOrder,
generateEnum
generateEnum,
useIsChecksForSealedObjectComparison
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ internal class SealedEnumProcessor(
sealedEnumOptions = sealedEnumAnnotations.associate {
it.traversalOrder to if (it.generateEnum) {
@Suppress("UnsafeCallOnNullableType") // Guaranteed safe by above any
SealedEnumWithEnum(sealedClassInterfaces!!)
SealedEnumWithEnum(sealedClassInterfaces!!, it.useIsChecksForSealedObjectComparison)
} else {
SealedEnumOnly
SealedEnumOnly(it.useIsChecksForSealedObjectComparison)
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ public data class SealedEnumFileSpec(
parameterizedSealedClass = parameterizedSealedClass,
sealedClassCompanionObjectElement = sealedClassCompanionObjectElement,
sealedObjects = sealedObjects,
enumPrefix = enumPrefix
enumPrefix = enumPrefix,
useIsChecksForSealedObjectComparison = sealedEnumOption.useIsChecksForSealedObjectComparison
)

val sealedEnumClassName = ClassName(sealedClass.packageName, sealedEnumTypeSpecBuilder.name)
Expand Down Expand Up @@ -243,16 +244,21 @@ public data class SealedEnumFileSpec(
* The options for generating classes for a [TreeTraversalOrder].
*/
public sealed class SealedEnumOption {
public abstract val useIsChecksForSealedObjectComparison: Boolean

/**
* Generate the [SealedEnum] only.
*/
public object SealedEnumOnly : SealedEnumOption()
public data class SealedEnumOnly(
override val useIsChecksForSealedObjectComparison: Boolean
) : SealedEnumOption()

/**
* Generate the [SealedEnum] and a isomorphic enum class.
*/
public data class SealedEnumWithEnum(
val sealedClassInterfaces: List<TypeName>
val sealedClassInterfaces: List<TypeName>,
override val useIsChecksForSealedObjectComparison: Boolean
) : SealedEnumOption()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ internal data class SealedEnumTypeSpec(
private val parameterizedSealedClass: TypeName,
private val sealedClassCompanionObjectElement: TypeElement?,
private val sealedObjects: List<SealedObject>,
private val enumPrefix: String
private val enumPrefix: String,
private val useIsChecksForSealedObjectComparison: Boolean
) {
val name = sealedClass.createSealedEnumName(enumPrefix)
private val listOfSealedClass = List::class.asClassName().parameterizedBy(parameterizedSealedClass)
Expand Down Expand Up @@ -108,7 +109,7 @@ internal data class SealedEnumTypeSpec(
} else {
beginControlFlow("return when (obj)")
sealedObjects.forEachIndexed { index, obj ->
addStatement("%T -> $index", obj)
addStatement(if (useIsChecksForSealedObjectComparison) "is %T -> $index" else "%T -> $index", obj)
}
endControlFlow()
}
Expand All @@ -131,7 +132,11 @@ internal data class SealedEnumTypeSpec(
} else {
beginControlFlow("return when (obj)")
sealedObjects.forEach { obj ->
addStatement("%T -> %S", obj, sealedObjectToName(obj))
addStatement(
if (useIsChecksForSealedObjectComparison) "is %T -> %S" else "%T -> %S",
obj,
sealedObjectToName(obj)
)
}
endControlFlow()
}
Expand Down Expand Up @@ -177,7 +182,7 @@ internal data class SealedEnumTypeSpec(
beginControlFlow("return when (obj)")
sealedObjects.forEach { obj ->
addStatement(
"%T -> %T",
if (useIsChecksForSealedObjectComparison) "is %T -> %T" else "%T -> %T",
obj,
enumForSealedEnum.nestedClass(obj.simpleNames.joinToString("_"))
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package com.livefront.sealedenum.compilation.equality

import com.livefront.sealedenum.GenSealedEnum
import org.intellij.lang.annotations.Language

sealed class Flag {
val i: Int = 1 shl ordinal
object FirstFlag : Flag()

object SecondFlag : Flag()

@GenSealedEnum(generateEnum = true, useIsChecksForSealedObjectComparison = true)
companion object
}

@Language("kotlin")
val flagGenerated = """
package com.livefront.sealedenum.compilation.equality
import com.livefront.sealedenum.EnumForSealedEnumProvider
import com.livefront.sealedenum.SealedEnum
import com.livefront.sealedenum.SealedEnumWithEnumProvider
import kotlin.Int
import kotlin.String
import kotlin.collections.List
import kotlin.reflect.KClass
/**
* An isomorphic enum for the sealed class [Flag]
*/
public enum class FlagEnum() {
Flag_FirstFlag,
Flag_SecondFlag,
}
/**
* The isomorphic [FlagEnum] for [this].
*/
public val Flag.`enum`: FlagEnum
get() = FlagSealedEnum.sealedObjectToEnum(this)
/**
* The isomorphic [Flag] for [this].
*/
public val FlagEnum.sealedObject: Flag
get() = FlagSealedEnum.enumToSealedObject(this)
/**
* An implementation of [SealedEnum] for the sealed class [Flag]
*/
public object FlagSealedEnum : SealedEnum<Flag>, SealedEnumWithEnumProvider<Flag, FlagEnum>,
EnumForSealedEnumProvider<Flag, FlagEnum> {
public override val values: List<Flag> = listOf(
Flag.FirstFlag,
Flag.SecondFlag
)
public override val enumClass: KClass<FlagEnum>
get() = FlagEnum::class
public override fun ordinalOf(obj: Flag): Int = when (obj) {
is Flag.FirstFlag -> 0
is Flag.SecondFlag -> 1
}
public override fun nameOf(obj: Flag): String = when (obj) {
is Flag.FirstFlag -> "Flag_FirstFlag"
is Flag.SecondFlag -> "Flag_SecondFlag"
}
public override fun valueOf(name: String): Flag = when (name) {
"Flag_FirstFlag" -> Flag.FirstFlag
"Flag_SecondFlag" -> Flag.SecondFlag
else -> throw IllegalArgumentException(""${'"'}No sealed enum constant ${'$'}name""${'"'})
}
public override fun sealedObjectToEnum(obj: Flag): FlagEnum = when (obj) {
is Flag.FirstFlag -> FlagEnum.Flag_FirstFlag
is Flag.SecondFlag -> FlagEnum.Flag_SecondFlag
}
public override fun enumToSealedObject(`enum`: FlagEnum): Flag = when (enum) {
FlagEnum.Flag_FirstFlag -> Flag.FirstFlag
FlagEnum.Flag_SecondFlag -> Flag.SecondFlag
}
}
/**
* The index of [this] in the values list.
*/
public val Flag.ordinal: Int
get() = FlagSealedEnum.ordinalOf(this)
/**
* The name of [this] for use with valueOf.
*/
public val Flag.name: String
get() = FlagSealedEnum.nameOf(this)
/**
* A list of all [Flag] objects.
*/
public val Flag.Companion.values: List<Flag>
get() = FlagSealedEnum.values
/**
* Returns an implementation of [SealedEnum] for the sealed class [Flag]
*/
public val Flag.Companion.sealedEnum: FlagSealedEnum
get() = FlagSealedEnum
/**
* Returns the [Flag] object for the given [name].
*
* If the given name doesn't correspond to any [Flag], an [IllegalArgumentException] will be thrown.
*/
public fun Flag.Companion.valueOf(name: String): Flag = FlagSealedEnum.valueOf(name)
""".trimIndent()
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.livefront.sealedenum.compilation.equality

import com.livefront.sealedenum.testing.assertCompiles
import com.livefront.sealedenum.testing.assertGeneratedFileMatches
import com.livefront.sealedenum.testing.compile
import com.livefront.sealedenum.testing.getCommonSourceFile
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class FlagTests {
@Test
fun `two objects sealed class`() {
assertEquals(
listOf(Flag.FirstFlag, Flag.SecondFlag),
FlagSealedEnum.values
)
}

@Test
fun `two enums for sealed class`() {
assertEquals(
listOf(
FlagEnum.Flag_FirstFlag,
FlagEnum.Flag_SecondFlag
),
enumValues<FlagEnum>().toList()
)
}

@Test
fun `two enums for sealed class with mapping`() {
assertEquals(
Flag.values.map(Flag::enum),
enumValues<FlagEnum>().toList()
)
}

@Test
fun `correct enum class`() {
assertEquals(FlagEnum::class, FlagSealedEnum.enumClass)
}

@Test
fun `compilation generates correct code`() {
val result = compile(getCommonSourceFile("compilation", "equality", "Flag.kt"))

assertCompiles(result)
assertGeneratedFileMatches("Flag_SealedEnum.kt", flagGenerated, result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ public class SealedEnumProcessor : AbstractProcessor() {
sealedEnumOptions = sealedEnumAnnotations.associate {
it.traversalOrder to if (it.generateEnum) {
@Suppress("UnsafeCallOnNullableType") // Guaranteed safe by above any call
SealedEnumWithEnum(sealedClassInterfaces!!)
SealedEnumWithEnum(sealedClassInterfaces!!, it.useIsChecksForSealedObjectComparison)
} else {
SealedEnumOnly
SealedEnumOnly(it.useIsChecksForSealedObjectComparison)
}
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ package com.livefront.sealedenum
@Repeatable
public annotation class GenSealedEnum(
val traversalOrder: TreeTraversalOrder = TreeTraversalOrder.IN_ORDER,
val generateEnum: Boolean = false
val generateEnum: Boolean = false,
val useIsChecksForSealedObjectComparison: Boolean = false
)

/**
Expand Down

0 comments on commit d794e89

Please sign in to comment.