diff --git a/tools/refactor_topo.py b/tools/refactor_topo.py index 839fe8f7..8aadbb1d 100644 --- a/tools/refactor_topo.py +++ b/tools/refactor_topo.py @@ -233,6 +233,188 @@ } +def sort_class_methods_by_convention(class_def: cst.ClassDef) -> cst.ClassDef: + """Sort methods and properties in a class according to Python conventions.""" + methods, properties = extract_methods_and_properties(class_def) + sorted_body = order_methods_by_convention(methods, properties) + + other_statements = [ + stmt for stmt in class_def.body.body if not isinstance(stmt, cst.FunctionDef) + ] + final_body = cst.IndentedBlock(body=other_statements + sorted_body) + return class_def.with_changes(body=final_body) + + +def extract_methods_and_properties( + class_def: cst.ClassDef, +) -> tuple[List[cst.FunctionDef], List[List[cst.FunctionDef]]]: + """ + Extract methods and properties (with setters grouped together) from a class. + + Returns: + - methods: Regular methods in the class. + - properties: List of grouped properties, where each group contains a getter + and its associated setter, if present. + """ + methods = [] + properties = {} + + for stmt in class_def.body.body: + if isinstance(stmt, cst.FunctionDef): + for decorator in stmt.decorators: + # Handle @property + if ( + isinstance(decorator.decorator, cst.Name) + and decorator.decorator.value == "property" + ): + properties[stmt.name.value] = [stmt] # Initialize with getter + # Handle @property.setter + elif ( + isinstance(decorator.decorator, cst.Attribute) + and decorator.decorator.attr.value == "setter" + ): + base_name = decorator.decorator.value.value # Extract base name + if base_name in properties: + properties[base_name].append( + stmt + ) # Add setter to the property group + else: + # Setter appears before the getter + properties[base_name] = [None, stmt] + + # Add non-property methods + if not any( + isinstance(decorator.decorator, cst.Name) + and decorator.decorator.value == "property" + or isinstance(decorator.decorator, cst.Attribute) + and decorator.decorator.attr.value == "setter" + for decorator in stmt.decorators + ): + methods.append(stmt) + + # Convert property dictionary into a sorted list of grouped properties + sorted_properties = [group for _, group in sorted(properties.items())] + + return methods, sorted_properties + + +def order_methods_by_convention( + methods: List[cst.FunctionDef], properties: List[List[cst.FunctionDef]] +) -> List[cst.BaseStatement]: + """ + Order methods and properties in a class by Python's conventional order with section headers. + + Sections: + - Constructor + - Properties (grouped by getter and setter) + - Class Methods + - Static Methods + - Public and Private Instance Methods + """ + + def method_key(method: cst.FunctionDef) -> tuple[int, str]: + name = method.name.value + decorators = { + decorator.decorator.value + for decorator in method.decorators + if isinstance(decorator.decorator, cst.Name) + } + + if name == "__init__": + return (0, name) # Constructor always comes first + elif name.startswith("__") and name.endswith("__"): + return (1, name) # Dunder methods follow + elif any( + decorator == "property" or decorator.endswith(".setter") + for decorator in decorators + ): + return (2, name) # Properties and setters follow dunder methods + elif "classmethod" in decorators: + return (3, name) # Class methods follow properties + elif "staticmethod" in decorators: + return (4, name) # Static methods follow class methods + elif not name.startswith("_"): + return (5, name) # Public instance methods + else: + return (6, name) # Private methods last + + # Flatten properties into a single sorted list + flattened_properties = [ + prop for group in properties for prop in group if prop is not None + ] + + # Separate __init__, class methods, static methods, and instance methods + init_methods = [m for m in methods if m.name.value == "__init__"] + class_methods = [ + m + for m in methods + if any(decorator.decorator.value == "classmethod" for decorator in m.decorators) + ] + static_methods = [ + m + for m in methods + if any( + decorator.decorator.value == "staticmethod" for decorator in m.decorators + ) + ] + instance_methods = [ + m + for m in methods + if m.name.value != "__init__" + and not any( + decorator.decorator.value in {"classmethod", "staticmethod"} + for decorator in m.decorators + ) + ] + + # Sort properties and each method group alphabetically + sorted_properties = sorted(flattened_properties, key=lambda prop: prop.name.value) + sorted_class_methods = sorted(class_methods, key=lambda m: m.name.value) + sorted_static_methods = sorted(static_methods, key=lambda m: m.name.value) + sorted_instance_methods = sorted(instance_methods, key=lambda m: method_key(m)) + + # Combine all sections with headers + ordered_sections: List[cst.BaseStatement] = [] + + if init_methods: + ordered_sections.append( + cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Constructor ----"))]) + ) + ordered_sections.extend(init_methods) + + if sorted_properties: + ordered_sections.append( + cst.SimpleStatementLine([cst.Expr(cst.Comment("# ---- Properties ----"))]) + ) + ordered_sections.extend(sorted_properties) + + if sorted_class_methods: + ordered_sections.append( + cst.SimpleStatementLine( + [cst.Expr(cst.Comment("# ---- Class Methods ----"))] + ) + ) + ordered_sections.extend(sorted_class_methods) + + if sorted_static_methods: + ordered_sections.append( + cst.SimpleStatementLine( + [cst.Expr(cst.Comment("# ---- Static Methods ----"))] + ) + ) + ordered_sections.extend(sorted_static_methods) + + if sorted_instance_methods: + ordered_sections.append( + cst.SimpleStatementLine( + [cst.Expr(cst.Comment("# ---- Instance Methods ----"))] + ) + ) + ordered_sections.extend(sorted_instance_methods) + + return ordered_sections + + class ImportCollector(cst.CSTVisitor): def __init__(self): self.imports: Set[str] = set() @@ -259,6 +441,22 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: self.extracted_classes[node.name.value] = node +class ClassMethodExtractor(cst.CSTVisitor): + def __init__(self): + self.class_methods: Dict[str, List[cst.FunctionDef]] = {} + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + class_name = node.name.value + self.class_methods[class_name] = [] + + for statement in node.body.body: + if isinstance(statement, cst.FunctionDef): + self.class_methods[class_name].append(statement) + + # Sort methods alphabetically by name + self.class_methods[class_name].sort(key=lambda method: method.name.value) + + class MixinClassExtractor(cst.CSTVisitor): def __init__(self): self.extracted_classes: Dict[str, cst.ClassDef] = {} @@ -285,6 +483,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: if self.current_scope_level == 0: self.functions.append(node) + def get_sorted_functions(self) -> List[cst.FunctionDef]: + return sorted(self.functions, key=lambda func: func.name.value) + class GlobalVariableExtractor(cst.CSTVisitor): def __init__(self): @@ -402,6 +603,7 @@ def write_topo_class_files( } for group_name, class_names in class_groups.items(): + module_docstring = f""" build123d topology @@ -442,9 +644,10 @@ def write_topo_class_files( source_tree.visit(variable_collector) group_classes = [ - extracted_classes[name] for name in class_names if name in extracted_classes + sort_class_methods_by_convention(extracted_classes[name]) + for name in class_names + if name in extracted_classes ] - # Add imports for base classes based on layer dependencies additional_imports = [] if group_name != "shape_core": @@ -535,7 +738,7 @@ def write_topo_class_files( body.append(var) body.append(cst.EmptyLine(indent=False)) - for func in function_collector.functions: + for func in function_collector.get_sorted_functions(): if func.name.value in function_source[group_name]: body.append(func) class_module = cst.Module(body=body, header=header)