Import loops in pytorch/torchgen/model.py

Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.

In this tutorial, we’ll explore how to identify and fix problematic import cycles using Codegen.

You can find the complete example code in our examples repository.

Overview

The steps to identify and fix import loops are as follows:

  1. Detect import loops
  2. Visualize them
  3. Identify problematic cycles with mixed static/dynamic imports
  4. Fix these cycles using Codegen

Step 1: Detect Import Loops

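  • Create a directed multigraph, so that several imports between the same two files are all kept
  • Loop through the imports in the codebase and add an edge from each importing file to the file it imports from
  • Find the strongly connected components with NetworkX (these are the import loops)

import networkx as nx

# `codebase` is the Codegen Codebase object for the repository being analyzed
G = nx.MultiDiGraph()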

# Add all edges to the graph
for imp in codebase.imports:
    if imp.from_file and imp.to_file:
        edge_color = "red" if imp.is_dynamic else "black"
        edge_label = "dynamic" if imp.is_dynamic else "static"

        # Store the import statement and its metadata
        G.add_edge(
            imp.to_file.filepath,
            imp.from_file.filepath,
            color=edge_color,
            label=edge_label,
            is_dynamic=imp.is_dynamic,
            import_statement=imp,  # Store the whole import object
            key=id(imp.import_statement),
        )
# Find strongly connected components
cycles = [scc for scc in nx.strongly_connected_components(G) if len(scc) > 1]

print(f"🔄 Found {len(cycles)} import cycles:")
for i, cycle in enumerate(cycles, 1):
    print(f"\nCycle #{i}:")
    print(f"Size: {len(cycle)} files")

    # Create subgraph for this cycle to count edges
    cycle_subgraph = G.subgraph(cycle)

    # Count total edges
    total_edges = cycle_subgraph.number_of_edges()
    print(f"Total number of imports in cycle: {total_edges}")

    # Count dynamic and static imports separately
    dynamic_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "red")
    static_imports = sum(1 for u, v, data in cycle_subgraph.edges(data=True) if data.get("color") == "black")

    print(f"Number of dynamic imports: {dynamic_imports}")
    print(f"Number of static imports: {static_imports}")
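
Why do strongly connected components correspond to import loops? Inside one component, every file can reach every other file by following import edges, so any component with more than one file contains at least one cycle. The dependency-free sketch below swaps NetworkX for a hand-written version of Tarjan's algorithm and runs it on a small import graph. The dict-of-lists graph format and the file names are hypothetical stand-ins for the filepaths used as nodes above.

```python
from typing import Dict, List, Set


def strongly_connected_components(graph: Dict[str, List[str]]) -> List[Set[str]]:
    """Tarjan's algorithm: return every strongly connected component of a digraph.

    `graph` maps each importing file to the files it imports (hypothetical format).
    """
    index: Dict[str, int] = {}    # DFS discovery order of each node
    lowlink: Dict[str, int] = {}  # smallest index reachable from the node's subtree
    stack: List[str] = []
    on_stack: Set[str] = set()
    components: List[Set[str]] = []

    def visit(node: str) -> None:
        # Assign the next discovery index (len is evaluated before the insert)
        index[node] = lowlink[node] = len(index)
        stack.append(node)
        on_stack.add(node)
        for target in graph.get(node, []):
            if target not in index:
                visit(target)
                lowlink[node] = min(lowlink[node], lowlink[target])
            elif target in on_stack:
                lowlink[node] = min(lowlink[node], index[target])
        # `node` is the root of a component: pop it and everything above it
        if lowlink[node] == index[node]:
            component: Set[str] = set()
            while True:
                member = stack.pop()
                on_stack.discard(member)
                component.add(member)
                if member == node:
                    break
            components.append(component)

    # Include files that only ever appear as import targets
    nodes = set(graph) | {t for targets in graph.values() for t in targets}
    for node in sorted(nodes):
        if node not in index:
            visit(node)
    return components


# Hypothetical toy graph: a.py -> b.py -> c.py -> a.py forms a loop
toy_imports = {
    "a.py": ["b.py"],
    "b.py": ["c.py"],
    "c.py": ["a.py", "utils.py"],
    "d.py": ["utils.py"],
}
toy_cycles = [scc for scc in strongly_connected_components(toy_imports) if len(scc) > 1]
print([sorted(scc) for scc in toy_cycles])  # [['a.py', 'b.py', 'c.py']]
```

Files that are not part of any loop, such as d.py and utils.py here, come back as single-file components, which is why the len(scc) > 1 filter discards them.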

Understanding Import Cycles

Not all import cycles are problematic! Here's an example of a cycle that looks like it should cause an error but does not, because one side of it uses a dynamic import.

# Top-level import in APoT_tensor.py
from quantizer import objectA

# Dynamic import in quantizer.py
def some_func():
    # Evaluated only when some_func() is called
    from APoT_tensor import objectB

A dynamic import is an import placed inside a function, method, or other executable body of code. Its execution is deferred: the import runs only when that function, method, or body of code is executed, not when the module is first loaded.
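
To see the difference at runtime, the sketch below writes two throwaway modules into a temporary directory and imports one of them. The module names cycle_demo_a and cycle_demo_b, and the helper run_cycle_demo, are hypothetical stand-ins for APoT_tensor.py and quantizer.py. With a top-level import on both sides, Python finds a partially initialized module and raises ImportError; when one side defers its import into a function, the same cycle loads cleanly.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Stand-in for APoT_tensor.py: top-level (static) import from module B
MODULE_A = "from cycle_demo_b import object_a\nobject_b = 'B'\n"

# Stand-ins for quantizer.py: a static import of module A, or a deferred one
MODULE_B_STATIC = "from cycle_demo_a import object_b\nobject_a = 'A'\n"
MODULE_B_DYNAMIC = (
    "object_a = 'A'\n"
    "def some_func():\n"
    "    from cycle_demo_a import object_b  # evaluated only when called\n"
    "    return object_b\n"
)


def run_cycle_demo(dynamic: bool) -> str:
    """Import cycle_demo_a and report what happens."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "cycle_demo_a.py").write_text(MODULE_A)
        (root / "cycle_demo_b.py").write_text(MODULE_B_DYNAMIC if dynamic else MODULE_B_STATIC)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            importlib.import_module("cycle_demo_a")
            if dynamic:
                # The cycle closes only now, after both modules finished loading
                return sys.modules["cycle_demo_b"].some_func()
            return "imported"
        except ImportError as exc:
            return f"ImportError: {exc}"
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("cycle_demo_a", None)
            sys.modules.pop("cycle_demo_b", None)


print(run_cycle_demo(dynamic=False))
print(run_cycle_demo(dynamic=True))  # prints B
```

On recent Python versions the static variant's error message mentions a partially initialized module and a likely circular import, which is the kind of runtime failure this tutorial is trying to prevent.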

You can use imp.is_dynamic to check whether an import is dynamic, which lets you tell deliberately deferred imports apart from top-level ones.

Step 2: Visualize Import Loops

  • Create a new subgraph to visualize one cycle
  • Color and label the edges by type (dynamic/static)
  • Visualize the cycle graph with codebase.visualize(graph)
cycle = cycles[0]

def create_single_loop_graph(cycle):
    cycle_graph = nx.MultiDiGraph()  # MultiDiGraph keeps parallel edges (several imports between the same two files)
    cycle = list(cycle)
    for i in range(len(cycle)):
        for j in range(len(cycle)):
            # Get all edges between these nodes from original graph
            edge_data_dict = G.get_edge_data(cycle[i], cycle[j])
            if edge_data_dict:
                # For each edge between these nodes
                for edge_key, edge_data in edge_data_dict.items():
                    # Add edge with all its attributes to cycle graph
                    cycle_graph.add_edge(cycle[i], cycle[j], **edge_data)
    return cycle_graph


cycle_graph = create_single_loop_graph(cycle)
codebase.visualize(cycle_graph)

Figure: Import loops in pytorch/torchgen/model.py
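
If you want to inspect a cycle outside codebase.visualize, the same coloring convention (red for dynamic, black for static) can be written out as Graphviz DOT text. This is a minimal sketch, assuming the cycle's edges have already been reduced to (importer, imported, is_dynamic) tuples; the helper name cycle_to_dot and the sample file names are hypothetical.

```python
from typing import Iterable, Tuple

Edge = Tuple[str, str, bool]  # (importing file, imported file, is_dynamic)


def cycle_to_dot(edges: Iterable[Edge]) -> str:
    """Render import edges as a Graphviz digraph: red = dynamic, black = static."""
    lines = ["digraph import_cycle {"]
    for importer, imported, is_dynamic in edges:
        color = "red" if is_dynamic else "black"
        label = "dynamic" if is_dynamic else "static"
        # A non-strict digraph keeps parallel edges, like a MultiDiGraph
        lines.append(f'  "{importer}" -> "{imported}" [color={color}, label="{label}"];')
    lines.append("}")
    return "\n".join(lines)


# Hypothetical two-file cycle with one mixed (static + dynamic) pair
sample_edges = [
    ("a.py", "b.py", False),
    ("a.py", "b.py", True),
    ("b.py", "a.py", False),
]
print(cycle_to_dot(sample_edges))
```

The output can be rendered with any Graphviz tool, for example dot -Tsvg.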

Step 3: Identify problematic cycles with mixed static & dynamic imports

The import loops we are really concerned about are those with mixed imports, where the same pair of files is connected by both static and dynamic imports.

Here’s an example of a problematic cycle that we want to fix:

# In flex_decoding.py
from .flex_attention import (
    compute_forward_block_mn,
    compute_forward_inner,
    # ... more static imports
)

# Also in flex_decoding.py
def create_flex_decoding_kernel(*args, **kwargs):
    from .flex_attention import set_head_dim_values  # dynamic import

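Here flex_decoding.py imports from flex_attention both at the top level and inside a function. When the same module is imported both ways inside a cycle, whether the modules load correctly depends on import order and on which names are already defined at that moment, so this can cause issues if not handled carefully.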
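
Outside Codegen, a rough version of this check can be done with Python's built-in ast module: an import nested inside a function counts as dynamic, anything else as static. The sketch below, with the hypothetical helpers classify_imports and mixed_modules, applies that rule to the flex_decoding.py excerpt above and flags modules imported both ways. It only approximates imp.is_dynamic; Codegen's own rule may treat other nested imports (for example inside try or if blocks) differently.

```python
import ast
from collections import defaultdict
from typing import Dict, Set

# The flex_decoding.py excerpt from above
SOURCE = '''
from .flex_attention import (
    compute_forward_block_mn,
    compute_forward_inner,
)

def create_flex_decoding_kernel(*args, **kwargs):
    from .flex_attention import set_head_dim_values  # dynamic import
'''


def classify_imports(source: str) -> Dict[str, Set[str]]:
    """Map each imported module to the kinds of import used: "static" and/or "dynamic"."""
    kinds: Dict[str, Set[str]] = defaultdict(set)

    def walk(node: ast.AST, in_function: bool) -> None:
        for child in ast.iter_child_nodes(node):
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                walk(child, True)  # everything below a def runs only when called
                continue
            kind = "dynamic" if in_function else "static"
            if isinstance(child, ast.ImportFrom):
                # Keep the leading dots of relative imports, e.g. ".flex_attention"
                kinds["." * child.level + (child.module or "")].add(kind)
            elif isinstance(child, ast.Import):
                for alias in child.names:
                    kinds[alias.name].add(kind)
            walk(child, in_function)

    walk(ast.parse(source), False)
    return dict(kinds)


def mixed_modules(source: str) -> Set[str]:
    """Modules imported both at module level and inside a function."""
    return {m for m, k in classify_imports(source).items() if k == {"static", "dynamic"}}


print(mixed_modules(SOURCE))  # {'.flex_attention'}
```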

Let’s find these problematic cycles:

def find_problematic_import_loops(G, sccs):
    """Find cycles where files have both static and dynamic imports between them."""
    problematic_cycles = []

    for i, scc in enumerate(sccs):
        if i == 2:  # skip the third import loop (index 2): it's incredibly long (and also invalid)
            continue
        mixed_import_files = {}  # (from_file, to_file) -> {dynamic: count, static: count}

        # Check all file pairs in the cycle
        for from_file in scc:
            for to_file in scc:
                if G.has_edge(from_file, to_file):
                    # Get all edges between these files
                    edges = G.get_edge_data(from_file, to_file)

                    # Count imports by type
                    dynamic_count = sum(1 for e in edges.values() if e["color"] == "red")
                    static_count = sum(1 for e in edges.values() if e["color"] == "black")

                    # If we have both types between same files, this is problematic
                    if dynamic_count > 0 and static_count > 0:
                        mixed_import_files[(from_file, to_file)] = {"dynamic": dynamic_count, "static": static_count, "edges": edges}

        if mixed_import_files:
            problematic_cycles.append({"files": scc, "mixed_imports": mixed_import_files, "index": i})

    # Print findings
    print(f"Found {len(problematic_cycles)} cycles with mixed imports:")
    for i, cycle in enumerate(problematic_cycles):
        print(f"\n⚠️  Problematic Cycle #{i + 1}:")
        print(f"\n⚠️  Index #{cycle['index']}:")
        print(f"Size: {len(cycle['files'])} files")

        for (from_file, to_file), data in cycle["mixed_imports"].items():
            print("\n📁 Mixed imports detected:")
            print(f"  From: {from_file}")
            print(f"  To:   {to_file}")
            print(f"  Dynamic imports: {data['dynamic']}")
            print(f"  Static imports: {data['static']}")

    return problematic_cycles

problematic_cycles = find_problematic_import_loops(G, cycles)

Step 4: Fix the loop by moving the shared symbols to a separate utils.py file

A common way to break this cycle is to move all the shared symbols into a separate utils file (here, flex_utils.py). We can do this with the method symbol.move_to_file:

# Paths of the two files involved in the import cycle
attention_file_path = "torch/_inductor/kernel/flex_attention.py"
decoding_file_path = "torch/_inductor/kernel/flex_decoding.py"

# Create new utils file
utils_file = codebase.create_file("torch/_inductor/kernel/flex_utils.py")

# Get the two files involved in the import cycle
decoding_file = codebase.get_file(decoding_file_path)
attention_file = codebase.get_file(attention_file_path)

# Track symbols to move
symbols_to_move = set()

# Find imports from flex_attention in flex_decoding
for imp in decoding_file.imports:
    if imp.from_file and imp.from_file.filepath == attention_file_path:
        # Get the actual symbol from flex_attention
        if imp.imported_symbol:
            symbols_to_move.add(imp.imported_symbol)

# Move identified symbols to utils file
for symbol in symbols_to_move:
    symbol.move_to_file(utils_file)

print(f"🔄 Moved {len(symbols_to_move)} symbols to flex_utils.py")
for symbol in symbols_to_move:
    print(symbol.name)
# run this command to have the changes take effect in the codebase
codebase.commit()
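
After the move, the decoding side depends on the utils file instead of on the attention file, so the static imports no longer form a loop. The sketch below checks that shape in isolation. It writes three hypothetical toy modules (shared_utils, decoding_mod, attention_mod, standing in for flex_utils.py, flex_decoding.py and flex_attention.py) and imports them in both orders. It models only the layout, not the real PyTorch files or what move_to_file generates.

```python
import importlib
import sys
import tempfile
from pathlib import Path
from typing import List

# Hypothetical toy modules modelling the layout after the move
MODULES = {
    # Shared symbols now live here and import nothing from the other two
    "shared_utils": "def shared_helper():\n    return 'shared'\n",
    # The decoding side imports only from the utils module
    "decoding_mod": (
        "from shared_utils import shared_helper\n"
        "def decode():\n"
        "    return shared_helper()\n"
    ),
    # The attention side may still depend on decoding; the edge back is gone
    "attention_mod": (
        "from shared_utils import shared_helper\n"
        "from decoding_mod import decode\n"
    ),
}


def import_in_order(order: List[str]) -> str:
    """Write the toy modules to a temp dir, import them in `order`, call decode()."""
    with tempfile.TemporaryDirectory() as tmp:
        for name, body in MODULES.items():
            (Path(tmp) / f"{name}.py").write_text(body)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            for name in order:
                importlib.import_module(name)
            return sys.modules["decoding_mod"].decode()
        finally:
            sys.path.remove(tmp)
            for name in MODULES:
                sys.modules.pop(name, None)


print(import_in_order(["attention_mod", "decoding_mod"]))  # shared
print(import_in_order(["decoding_mod", "attention_mod"]))  # shared
```

Because every import now points toward shared_utils, either module can be imported first without running into a partially initialized module.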

Next Steps

Verify that all tests pass after the migration, then fix the other problematic import loops using one of these strategies:

  1. Move the shared symbols to a separate file
  2. If a module needs an import only for type hints, guard it with if TYPE_CHECKING: (TYPE_CHECKING comes from the typing module)
  3. Use lazy imports, for example importlib.import_module inside a function, so the module is loaded only when it is needed
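
Strategies 2 and 3 can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical module other_module that exports a hypothetical SharedType; because typing.TYPE_CHECKING is False at runtime and annotations are postponed, that module is never imported when the code runs. The helper lazy_import is likewise hypothetical and simply wraps importlib.import_module.

```python
from __future__ import annotations  # annotations are stored as strings, not evaluated

import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, so this import can never take part
    # in a runtime import cycle. "other_module" / "SharedType" are hypothetical.
    from other_module import SharedType


def describe(value: SharedType) -> str:
    # The annotation is never resolved at runtime, so other_module need not exist
    return f"describing {value!r}"


def lazy_import(module_name: str):
    """Resolve a module only when this function is called (a lazy import)."""
    return importlib.import_module(module_name)


print(describe("x"))  # describing 'x'
print(lazy_import("json").dumps({"ok": True}))  # {"ok": true}
```

The TYPE_CHECKING guard removes an edge from the runtime import graph entirely, while a lazy import keeps the edge but moves it out of module load time, much like the dynamic imports discussed earlier.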