diff --git a/tools/refactor_topo.py b/tools/refactor_topo.py index 51a6b6f9..4adbc0bb 100644 --- a/tools/refactor_topo.py +++ b/tools/refactor_topo.py @@ -333,51 +333,6 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.extracted_methods.append(renamed_node) -class OptionalTransformer(cst.CSTTransformer): - def __init__(self): - super().__init__() - self.requires_optional_import = False # Tracks if `Optional` import is needed - - def leave_AnnAssign( - self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign - ) -> cst.AnnAssign: - # Check if the annotation has a default value of `None` - if ( - isinstance(updated_node.value, cst.Name) - and updated_node.value.value == "None" - ): - # Wrap the annotation type in `Optional` - if updated_node.annotation: - self.requires_optional_import = True - new_annotation = cst.Subscript( - value=cst.Name("Optional"), - slice=[ - cst.SubscriptElement( - slice=cst.Index(updated_node.annotation.annotation) - ) - ], - ) - return updated_node.with_changes( - annotation=cst.Annotation(new_annotation) - ) - return updated_node - - def leave_Module( - self, original_node: cst.Module, updated_node: cst.Module - ) -> cst.Module: - # Add the `Optional` import if required - if self.requires_optional_import: - import_stmt = cst.ImportFrom( - module=cst.Name("typing"), - names=[cst.ImportAlias(name=cst.Name("Optional"))], - ) - new_body = [cst.SimpleStatementLine([import_stmt])] + list( - updated_node.body - ) - return updated_node.with_changes(body=new_body) - return updated_node - - def write_topo_class_files( source_tree: cst.Module, extracted_classes: Dict[str, cst.ClassDef], @@ -809,12 +764,10 @@ def main(): # Parse source file and collect imports source_tree = cst.parse_module(topo_file.read_text()) + source_tree = source_tree.visit(UnionToPipeTransformer()) + # transformed_module = source_tree.visit(UnionToPipeTransformer()) + # print(transformed_module.code) - # Apply transformations - source_tree = source_tree.visit(UnionToPipeTransformer()) # Existing transformation - source_tree = source_tree.visit(OptionalTransformer()) # New Optional conversion - - # Collect imports collector = ImportCollector() source_tree.visit(collector) @@ -829,6 +782,8 @@ def main(): # Extract functions function_collector = StandaloneFunctionAndVariableCollector() source_tree.visit(function_collector) + # for f in function_collector.functions: + # print(f.name.value) # Write the class files write_topo_class_files( @@ -838,8 +793,11 @@ def main(): output_dir=output_dir, ) - # Clean up imports + # Create a Rope project instance + # project = Project(str(script_dir)) project = Project(str(output_dir)) + + # Clean up imports for file in output_dir.glob("*.py"): if file.name == "__init__.py": continue