Skip to content

Commit

Permalink
Replace equality checks on sealed objects with is checks (fixes #129)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyay10 committed Apr 10, 2023
1 parent 12257ab commit 27998dd
Show file tree
Hide file tree
Showing 23 changed files with 621 additions and 440 deletions.
36 changes: 18 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ object AlphaSealedEnum : SealedEnum<Alpha> {
)

override fun ordinalOf(obj: Alpha): Int = when (obj) {
Alpha.Beta -> 0
Alpha.Gamma -> 1
is Alpha.Beta -> 0
is Alpha.Gamma -> 1
}

override fun nameOf(obj: AlphaSealedEnum): String = when (obj) {
Alpha.Beta -> "Alpha_Beta"
Alpha.Gamma -> "Alpha_Gamma"
is Alpha.Beta -> "Alpha_Beta"
is Alpha.Gamma -> "Alpha_Gamma"
}

override fun valueOf(name: String): AlphaSealedEnum = when (name) {
Expand Down Expand Up @@ -157,15 +157,15 @@ object AlphaLevelOrderSealedEnum : SealedEnum<Alpha> {
)

override fun ordinalOf(obj: Alpha): Int = when (obj) {
Alpha.Delta -> 0
Alpha.Beta.Gamma -> 1
Alpha.Epsilon.Zeta -> 2
is Alpha.Delta -> 0
is Alpha.Beta.Gamma -> 1
is Alpha.Epsilon.Zeta -> 2
}

override fun nameOf(obj: AlphaLevelOrderSealedEnum): String = when (obj) {
Alpha.Delta -> "Alpha_Delta"
Alpha.Beta.Gamma -> "Alpha_Beta_Gamma"
Alpha.Epsilon.Zeta -> "Alpha_Epsilon_Zeta"
is Alpha.Delta -> "Alpha_Delta"
is Alpha.Beta.Gamma -> "Alpha_Beta_Gamma"
is Alpha.Epsilon.Zeta -> "Alpha_Epsilon_Zeta"
}

override fun valueOf(name: String): AlphaLevelOrderSealedEnum = when (name) {
Expand All @@ -184,15 +184,15 @@ object AlphaInOrderSealedEnum : SealedEnum<Alpha> {
)

override fun ordinalOf(obj: Alpha): Int = when (obj) {
Alpha.Beta.Gamma -> 0
Alpha.Delta -> 1
Alpha.Epsilon.Zeta -> 2
is Alpha.Beta.Gamma -> 0
is Alpha.Delta -> 1
is Alpha.Epsilon.Zeta -> 2
}

override fun nameOf(obj: AlphaInOrderSealedEnum): String = when (obj) {
Alpha.Beta.Gamma -> "Alpha_Beta_Gamma"
Alpha.Delta -> "Alpha_Delta"
Alpha.Epsilon.Zeta -> "Alpha_Epsilon_Zeta"
is Alpha.Beta.Gamma -> "Alpha_Beta_Gamma"
is Alpha.Delta -> "Alpha_Delta"
is Alpha.Epsilon.Zeta -> "Alpha_Epsilon_Zeta"
}

override fun valueOf(name: String): AlphaInOrderSealedEnum = when (name) {
Expand Down Expand Up @@ -293,8 +293,8 @@ object AlphaSealedEnum : SealedEnum<Alpha>, SealedEnumWithEnumProvider<Alpha, Al
)

override fun sealedObjectToEnum(obj: Alpha): AlphaEnum = when (obj) {
Alpha.Beta -> AlphaEnum.Alpha_Beta
Alpha.Gamma -> AlphaEnum.Alpha_Gamma
is Alpha.Beta -> AlphaEnum.Alpha_Beta
is Alpha.Gamma -> AlphaEnum.Alpha_Gamma
}

override fun enumToSealedObject(enum: AlphaEnum): Alpha = when (enum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ internal data class SealedEnumTypeSpec(
} else {
beginControlFlow("return when (obj)")
sealedObjects.forEachIndexed { index, obj ->
addStatement("%T -> $index", obj)
addStatement("is %T -> $index", obj)
}
endControlFlow()
}
Expand All @@ -131,7 +131,7 @@ internal data class SealedEnumTypeSpec(
} else {
beginControlFlow("return when (obj)")
sealedObjects.forEach { obj ->
addStatement("%T -> %S", obj, sealedObjectToName(obj))
addStatement("is %T -> %S", obj, sealedObjectToName(obj))
}
endControlFlow()
}
Expand Down Expand Up @@ -177,7 +177,7 @@ internal data class SealedEnumTypeSpec(
beginControlFlow("return when (obj)")
sealedObjects.forEach { obj ->
addStatement(
"%T -> %T",
"is %T -> %T",
obj,
enumForSealedEnum.nestedClass(obj.simpleNames.joinToString("_"))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ public object OneObjectSealedClassSealedEnum : SealedEnum<OneObjectSealedClass>,
get() = OneObjectSealedClassEnum::class
public override fun ordinalOf(obj: OneObjectSealedClass): Int = when (obj) {
OneObjectSealedClass.FirstObject -> 0
is OneObjectSealedClass.FirstObject -> 0
}
public override fun nameOf(obj: OneObjectSealedClass): String = when (obj) {
OneObjectSealedClass.FirstObject -> "OneObjectSealedClass_FirstObject"
is OneObjectSealedClass.FirstObject -> "OneObjectSealedClass_FirstObject"
}
public override fun valueOf(name: String): OneObjectSealedClass = when (name) {
Expand All @@ -70,7 +70,7 @@ public object OneObjectSealedClassSealedEnum : SealedEnum<OneObjectSealedClass>,
public override fun sealedObjectToEnum(obj: OneObjectSealedClass): OneObjectSealedClassEnum =
when (obj) {
OneObjectSealedClass.FirstObject ->
is OneObjectSealedClass.FirstObject ->
OneObjectSealedClassEnum.OneObjectSealedClass_FirstObject
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ public object OneObjectSealedInterfaceSealedEnum : SealedEnum<OneObjectSealedInt
get() = OneObjectSealedInterfaceEnum::class
public override fun ordinalOf(obj: OneObjectSealedInterface): Int = when (obj) {
OneObjectSealedInterface.FirstObject -> 0
is OneObjectSealedInterface.FirstObject -> 0
}
public override fun nameOf(obj: OneObjectSealedInterface): String = when (obj) {
OneObjectSealedInterface.FirstObject -> "OneObjectSealedInterface_FirstObject"
is OneObjectSealedInterface.FirstObject -> "OneObjectSealedInterface_FirstObject"
}
public override fun valueOf(name: String): OneObjectSealedInterface = when (name) {
Expand All @@ -70,7 +70,7 @@ public object OneObjectSealedInterfaceSealedEnum : SealedEnum<OneObjectSealedInt
public override fun sealedObjectToEnum(obj: OneObjectSealedInterface):
OneObjectSealedInterfaceEnum = when (obj) {
OneObjectSealedInterface.FirstObject ->
is OneObjectSealedInterface.FirstObject ->
OneObjectSealedInterfaceEnum.OneObjectSealedInterface_FirstObject
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ public object TwoObjectSealedClassSealedEnum : SealedEnum<TwoObjectSealedClass>,
get() = TwoObjectSealedClassEnum::class
public override fun ordinalOf(obj: TwoObjectSealedClass): Int = when (obj) {
TwoObjectSealedClass.FirstObject -> 0
TwoObjectSealedClass.SecondObject -> 1
is TwoObjectSealedClass.FirstObject -> 0
is TwoObjectSealedClass.SecondObject -> 1
}
public override fun nameOf(obj: TwoObjectSealedClass): String = when (obj) {
TwoObjectSealedClass.FirstObject -> "TwoObjectSealedClass_FirstObject"
TwoObjectSealedClass.SecondObject -> "TwoObjectSealedClass_SecondObject"
is TwoObjectSealedClass.FirstObject -> "TwoObjectSealedClass_FirstObject"
is TwoObjectSealedClass.SecondObject -> "TwoObjectSealedClass_SecondObject"
}
public override fun valueOf(name: String): TwoObjectSealedClass = when (name) {
Expand All @@ -77,9 +77,9 @@ public object TwoObjectSealedClassSealedEnum : SealedEnum<TwoObjectSealedClass>,
public override fun sealedObjectToEnum(obj: TwoObjectSealedClass): TwoObjectSealedClassEnum =
when (obj) {
TwoObjectSealedClass.FirstObject ->
is TwoObjectSealedClass.FirstObject ->
TwoObjectSealedClassEnum.TwoObjectSealedClass_FirstObject
TwoObjectSealedClass.SecondObject ->
is TwoObjectSealedClass.SecondObject ->
TwoObjectSealedClassEnum.TwoObjectSealedClass_SecondObject
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ public object TwoObjectSealedInterfaceSealedEnum : SealedEnum<TwoObjectSealedInt
get() = TwoObjectSealedInterfaceEnum::class
public override fun ordinalOf(obj: TwoObjectSealedInterface): Int = when (obj) {
TwoObjectSealedInterface.FirstObject -> 0
TwoObjectSealedInterface.SecondObject -> 1
is TwoObjectSealedInterface.FirstObject -> 0
is TwoObjectSealedInterface.SecondObject -> 1
}
public override fun nameOf(obj: TwoObjectSealedInterface): String = when (obj) {
TwoObjectSealedInterface.FirstObject -> "TwoObjectSealedInterface_FirstObject"
TwoObjectSealedInterface.SecondObject -> "TwoObjectSealedInterface_SecondObject"
is TwoObjectSealedInterface.FirstObject -> "TwoObjectSealedInterface_FirstObject"
is TwoObjectSealedInterface.SecondObject -> "TwoObjectSealedInterface_SecondObject"
}
public override fun valueOf(name: String): TwoObjectSealedInterface = when (name) {
Expand All @@ -77,9 +77,9 @@ public object TwoObjectSealedInterfaceSealedEnum : SealedEnum<TwoObjectSealedInt
public override fun sealedObjectToEnum(obj: TwoObjectSealedInterface):
TwoObjectSealedInterfaceEnum = when (obj) {
TwoObjectSealedInterface.FirstObject ->
is TwoObjectSealedInterface.FirstObject ->
TwoObjectSealedInterfaceEnum.TwoObjectSealedInterface_FirstObject
TwoObjectSealedInterface.SecondObject ->
is TwoObjectSealedInterface.SecondObject ->
TwoObjectSealedInterfaceEnum.TwoObjectSealedInterface_SecondObject
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package com.livefront.sealedenum.compilation.equality

import com.livefront.sealedenum.GenSealedEnum
import org.intellij.lang.annotations.Language

sealed class Flag {
val i: Int = 1 shl ordinal
object FirstFlag : Flag()

object SecondFlag : Flag()

@GenSealedEnum(generateEnum = true)
companion object
}

@Language("kotlin")
val flagGenerated = """
package com.livefront.sealedenum.compilation.equality
import com.livefront.sealedenum.EnumForSealedEnumProvider
import com.livefront.sealedenum.SealedEnum
import com.livefront.sealedenum.SealedEnumWithEnumProvider
import kotlin.Int
import kotlin.String
import kotlin.collections.List
import kotlin.reflect.KClass
/**
* An isomorphic enum for the sealed class [Flag]
*/
public enum class FlagEnum() {
Flag_FirstFlag,
Flag_SecondFlag,
}
/**
* The isomorphic [FlagEnum] for [this].
*/
public val Flag.`enum`: FlagEnum
get() = FlagSealedEnum.sealedObjectToEnum(this)
/**
* The isomorphic [Flag] for [this].
*/
public val FlagEnum.sealedObject: Flag
get() = FlagSealedEnum.enumToSealedObject(this)
/**
* An implementation of [SealedEnum] for the sealed class [Flag]
*/
public object FlagSealedEnum : SealedEnum<Flag>, SealedEnumWithEnumProvider<Flag, FlagEnum>,
EnumForSealedEnumProvider<Flag, FlagEnum> {
public override val values: List<Flag> = listOf(
Flag.FirstFlag,
Flag.SecondFlag
)
public override val enumClass: KClass<FlagEnum>
get() = FlagEnum::class
public override fun ordinalOf(obj: Flag): Int = when (obj) {
is Flag.FirstFlag -> 0
is Flag.SecondFlag -> 1
}
public override fun nameOf(obj: Flag): String = when (obj) {
is Flag.FirstFlag -> "Flag_FirstFlag"
is Flag.SecondFlag -> "Flag_SecondFlag"
}
public override fun valueOf(name: String): Flag = when (name) {
"Flag_FirstFlag" -> Flag.FirstFlag
"Flag_SecondFlag" -> Flag.SecondFlag
else -> throw IllegalArgumentException(""${'"'}No sealed enum constant ${'$'}name""${'"'})
}
public override fun sealedObjectToEnum(obj: Flag): FlagEnum = when (obj) {
is Flag.FirstFlag -> FlagEnum.Flag_FirstFlag
is Flag.SecondFlag -> FlagEnum.Flag_SecondFlag
}
public override fun enumToSealedObject(`enum`: FlagEnum): Flag = when (enum) {
FlagEnum.Flag_FirstFlag -> Flag.FirstFlag
FlagEnum.Flag_SecondFlag -> Flag.SecondFlag
}
}
/**
* The index of [this] in the values list.
*/
public val Flag.ordinal: Int
get() = FlagSealedEnum.ordinalOf(this)
/**
* The name of [this] for use with valueOf.
*/
public val Flag.name: String
get() = FlagSealedEnum.nameOf(this)
/**
* A list of all [Flag] objects.
*/
public val Flag.Companion.values: List<Flag>
get() = FlagSealedEnum.values
/**
* Returns an implementation of [SealedEnum] for the sealed class [Flag]
*/
public val Flag.Companion.sealedEnum: FlagSealedEnum
get() = FlagSealedEnum
/**
* Returns the [Flag] object for the given [name].
*
* If the given name doesn't correspond to any [Flag], an [IllegalArgumentException] will be thrown.
*/
public fun Flag.Companion.valueOf(name: String): Flag = FlagSealedEnum.valueOf(name)
""".trimIndent()
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.livefront.sealedenum.compilation.equality

import com.livefront.sealedenum.testing.assertCompiles
import com.livefront.sealedenum.testing.assertGeneratedFileMatches
import com.livefront.sealedenum.testing.compile
import com.livefront.sealedenum.testing.getCommonSourceFile
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test

class FlagTests {
@Test
fun `two objects sealed class`() {
assertEquals(
listOf(Flag.FirstFlag, Flag.SecondFlag),
FlagSealedEnum.values
)
}

@Test
fun `two enums for sealed class`() {
assertEquals(
listOf(
FlagEnum.Flag_FirstFlag,
FlagEnum.Flag_SecondFlag
),
enumValues<FlagEnum>().toList()
)
}

@Test
fun `two enums for sealed class with mapping`() {
assertEquals(
Flag.values.map(Flag::enum),
enumValues<FlagEnum>().toList()
)
}

@Test
fun `correct enum class`() {
assertEquals(FlagEnum::class, FlagSealedEnum.enumClass)
}

@Test
fun `compilation generates correct code`() {
val result = compile(getCommonSourceFile("compilation", "equality", "Flag.kt"))

assertCompiles(result)
assertGeneratedFileMatches("Flag_SealedEnum.kt", flagGenerated, result)
}
}
Loading

0 comments on commit 27998dd

Please sign in to comment.