Skip to content

Commit

Permalink
Revert "Added Optional to input parameters"
Browse files Browse the repository at this point in the history
This reverts commit 87c046b.
  • Loading branch information
gumyr committed Dec 17, 2024
1 parent 87c046b commit 127d048
Showing 1 changed file with 9 additions and 51 deletions.
60 changes: 9 additions & 51 deletions tools/refactor_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,51 +333,6 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.extracted_methods.append(renamed_node)


class OptionalTransformer(cst.CSTTransformer):
def __init__(self):
super().__init__()
self.requires_optional_import = False # Tracks if `Optional` import is needed

def leave_AnnAssign(
self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
) -> cst.AnnAssign:
# Check if the annotation has a default value of `None`
if (
isinstance(updated_node.value, cst.Name)
and updated_node.value.value == "None"
):
# Wrap the annotation type in `Optional`
if updated_node.annotation:
self.requires_optional_import = True
new_annotation = cst.Subscript(
value=cst.Name("Optional"),
slice=[
cst.SubscriptElement(
slice=cst.Index(updated_node.annotation.annotation)
)
],
)
return updated_node.with_changes(
annotation=cst.Annotation(new_annotation)
)
return updated_node

def leave_Module(
self, original_node: cst.Module, updated_node: cst.Module
) -> cst.Module:
# Add the `Optional` import if required
if self.requires_optional_import:
import_stmt = cst.ImportFrom(
module=cst.Name("typing"),
names=[cst.ImportAlias(name=cst.Name("Optional"))],
)
new_body = [cst.SimpleStatementLine([import_stmt])] + list(
updated_node.body
)
return updated_node.with_changes(body=new_body)
return updated_node


def write_topo_class_files(
source_tree: cst.Module,
extracted_classes: Dict[str, cst.ClassDef],
Expand Down Expand Up @@ -809,12 +764,10 @@ def main():

# Parse source file and collect imports
source_tree = cst.parse_module(topo_file.read_text())
source_tree = source_tree.visit(UnionToPipeTransformer())
# transformed_module = source_tree.visit(UnionToPipeTransformer())
# print(transformed_module.code)

# Apply transformations
source_tree = source_tree.visit(UnionToPipeTransformer()) # Existing transformation
source_tree = source_tree.visit(OptionalTransformer()) # New Optional conversion

# Collect imports
collector = ImportCollector()
source_tree.visit(collector)

Expand All @@ -829,6 +782,8 @@ def main():
# Extract functions
function_collector = StandaloneFunctionAndVariableCollector()
source_tree.visit(function_collector)
# for f in function_collector.functions:
# print(f.name.value)

# Write the class files
write_topo_class_files(
Expand All @@ -838,8 +793,11 @@ def main():
output_dir=output_dir,
)

# Clean up imports
# Create a Rope project instance
# project = Project(str(script_dir))
project = Project(str(output_dir))

# Clean up imports
for file in output_dir.glob("*.py"):
if file.name == "__init__.py":
continue
Expand Down

0 comments on commit 127d048

Please sign in to comment.