Skip to content

Commit

Permalink
Order functions, methods and properties in a class by Python's conven…
Browse files Browse the repository at this point in the history
…tional order
  • Loading branch information
gumyr committed Jan 4, 2025
1 parent b539663 commit 93513b1
Showing 1 changed file with 206 additions and 3 deletions.
209 changes: 206 additions & 3 deletions tools/refactor_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,188 @@
}


def sort_class_methods_by_convention(class_def: cst.ClassDef) -> cst.ClassDef:
"""Sort methods and properties in a class according to Python conventions."""
methods, properties = extract_methods_and_properties(class_def)
sorted_body = order_methods_by_convention(methods, properties)

other_statements = [
stmt for stmt in class_def.body.body if not isinstance(stmt, cst.FunctionDef)
]
final_body = cst.IndentedBlock(body=other_statements + sorted_body)
return class_def.with_changes(body=final_body)


def extract_methods_and_properties(
class_def: cst.ClassDef,
) -> tuple[List[cst.FunctionDef], List[List[cst.FunctionDef]]]:
"""
Extract methods and properties (with setters grouped together) from a class.
Returns:
- methods: Regular methods in the class.
- properties: List of grouped properties, where each group contains a getter
and its associated setter, if present.
"""
methods = []
properties = {}

for stmt in class_def.body.body:
if isinstance(stmt, cst.FunctionDef):
for decorator in stmt.decorators:
# Handle @property
if (
isinstance(decorator.decorator, cst.Name)
and decorator.decorator.value == "property"
):
properties[stmt.name.value] = [stmt] # Initialize with getter
# Handle @property.setter
elif (
isinstance(decorator.decorator, cst.Attribute)
and decorator.decorator.attr.value == "setter"
):
base_name = decorator.decorator.value.value # Extract base name
if base_name in properties:
properties[base_name].append(
stmt
) # Add setter to the property group
else:
# Setter appears before the getter
properties[base_name] = [None, stmt]

# Add non-property methods
if not any(
isinstance(decorator.decorator, cst.Name)
and decorator.decorator.value == "property"
or isinstance(decorator.decorator, cst.Attribute)
and decorator.decorator.attr.value == "setter"
for decorator in stmt.decorators
):
methods.append(stmt)

# Convert property dictionary into a sorted list of grouped properties
sorted_properties = [group for _, group in sorted(properties.items())]

return methods, sorted_properties


def order_methods_by_convention(
methods: List[cst.FunctionDef], properties: List[List[cst.FunctionDef]]
) -> List[cst.BaseStatement]:
"""
Order methods and properties in a class by Python's conventional order with section headers.
Sections:
- Constructor
- Properties (grouped by getter and setter)
- Class Methods
- Static Methods
- Public and Private Instance Methods
"""

def method_key(method: cst.FunctionDef) -> tuple[int, str]:
name = method.name.value
decorators = {
decorator.decorator.value
for decorator in method.decorators
if isinstance(decorator.decorator, cst.Name)
}

if name == "__init__":
return (0, name) # Constructor always comes first
elif name.startswith("__") and name.endswith("__"):
return (1, name) # Dunder methods follow
elif any(
decorator == "property" or decorator.endswith(".setter")
for decorator in decorators
):
return (2, name) # Properties and setters follow dunder methods
elif "classmethod" in decorators:
return (3, name) # Class methods follow properties
elif "staticmethod" in decorators:
return (4, name) # Static methods follow class methods
elif not name.startswith("_"):
return (5, name) # Public instance methods
else:
return (6, name) # Private methods last

# Flatten properties into a single sorted list
flattened_properties = [
prop for group in properties for prop in group if prop is not None
]

# Separate __init__, class methods, static methods, and instance methods
init_methods = [m for m in methods if m.name.value == "__init__"]
class_methods = [
m
for m in methods
if any(decorator.decorator.value == "classmethod" for decorator in m.decorators)
]
static_methods = [
m
for m in methods
if any(
decorator.decorator.value == "staticmethod" for decorator in m.decorators
)
]
instance_methods = [
m
for m in methods
if m.name.value != "__init__"
and not any(
decorator.decorator.value in {"classmethod", "staticmethod"}
for decorator in m.decorators
)
]

# Sort properties and each method group alphabetically
sorted_properties = sorted(flattened_properties, key=lambda prop: prop.name.value)
sorted_class_methods = sorted(class_methods, key=lambda m: m.name.value)
sorted_static_methods = sorted(static_methods, key=lambda m: m.name.value)
sorted_instance_methods = sorted(instance_methods, key=lambda m: method_key(m))

# Combine all sections with headers
ordered_sections: List[cst.BaseStatement] = []

if init_methods:
ordered_sections.append(
cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Constructor ----"))])
)
ordered_sections.extend(init_methods)

if sorted_properties:
ordered_sections.append(
cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Properties ----"))])
)
ordered_sections.extend(sorted_properties)

if sorted_class_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Class Methods ----"))]
)
)
ordered_sections.extend(sorted_class_methods)

if sorted_static_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Static Methods ----"))]
)
)
ordered_sections.extend(sorted_static_methods)

if sorted_instance_methods:
ordered_sections.append(
cst.SimpleStatementLine(
[cst.Expr(cst.Comment("# ---- Instance Methods ----"))]
)
)
ordered_sections.extend(sorted_instance_methods)

return ordered_sections


class ImportCollector(cst.CSTVisitor):
def __init__(self):
self.imports: Set[str] = set()
Expand All @@ -259,6 +441,22 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.extracted_classes[node.name.value] = node


class ClassMethodExtractor(cst.CSTVisitor):
def __init__(self):
self.class_methods: Dict[str, List[cst.FunctionDef]] = {}

def visit_ClassDef(self, node: cst.ClassDef) -> None:
class_name = node.name.value
self.class_methods[class_name] = []

for statement in node.body.body:
if isinstance(statement, cst.FunctionDef):
self.class_methods[class_name].append(statement)

# Sort methods alphabetically by name
self.class_methods[class_name].sort(key=lambda method: method.name.value)


class MixinClassExtractor(cst.CSTVisitor):
def __init__(self):
self.extracted_classes: Dict[str, cst.ClassDef] = {}
Expand All @@ -285,6 +483,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if self.current_scope_level == 0:
self.functions.append(node)

def get_sorted_functions(self) -> List[cst.FunctionDef]:
return sorted(self.functions, key=lambda func: func.name.value)


class GlobalVariableExtractor(cst.CSTVisitor):
def __init__(self):
Expand Down Expand Up @@ -402,6 +603,7 @@ def write_topo_class_files(
}

for group_name, class_names in class_groups.items():

module_docstring = f"""
build123d topology
Expand Down Expand Up @@ -442,9 +644,10 @@ def write_topo_class_files(
source_tree.visit(variable_collector)

group_classes = [
extracted_classes[name] for name in class_names if name in extracted_classes
sort_class_methods_by_convention(extracted_classes[name])
for name in class_names
if name in extracted_classes
]

# Add imports for base classes based on layer dependencies
additional_imports = []
if group_name != "shape_core":
Expand Down Expand Up @@ -535,7 +738,7 @@ def write_topo_class_files(
body.append(var)
body.append(cst.EmptyLine(indent=False))

for func in function_collector.functions:
for func in function_collector.get_sorted_functions():
if func.name.value in function_source[group_name]:
body.append(func)
class_module = cst.Module(body=body, header=header)
Expand Down

0 comments on commit 93513b1

Please sign in to comment.