Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 62 additions & 23 deletions hiddenlayer/pytorch_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
HiddenLayer

PyTorch graph importer.

Written by Waleed Abdulla
Licensed under the MIT License
"""
Expand All @@ -18,7 +18,7 @@
# Hide onnx: prefix
ht.Rename(op=r"onnx::(.*)", to=r"\1"),
# ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
# It's an odd name that no one recognizes. Rename it.
ht.Rename(op=r"Gemm", to=r"Linear"),
# PyTorch layers that don't have an ONNX counterpart
ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
Expand Down Expand Up @@ -49,10 +49,9 @@ def get_shape(torch_node):
"""Return the output shape of the given Pytorch node."""
# Extract node output shape from the node string representation
# This is a hack because there doesn't seem to be an official way to do it.
# See question in the PyTorch forum:
# https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
# TODO: find a better way to extract output shape
# TODO: Assuming the node has one output. Update if we encounter a multi-output node.
# Assuming the node has one output. Update if we encounter a multi-output node.
m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
if m:
shape = m.group(1)
Expand All @@ -64,34 +63,74 @@ def get_shape(torch_node):


def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Build a hiddenlayer graph from a PyTorch JIT trace of ``model``.

    Args:
        hl_graph: Graph object to populate. Only its ``add_node`` and
            ``add_edge_by_id`` methods are used here.
        model: PyTorch model to trace.
        args: Example input(s) handed to the tracer.
        input_names: Currently unused.
            TODO: add input names to graph.
        verbose: When true, dump the traced graph's nodes (debugging aid).

    Returns:
        ``hl_graph``, after one node per traced op and the connecting edges
        have been added.
    """
    # Trace the model. The traced graph is used directly; the ONNX
    # optimization pass (torch.onnx._optimize_trace) is intentionally skipped.
    torch_graph, _ = torch.jit._get_trace_graph(model, args)

    if verbose:
        dump_pytorch_graph(torch_graph)

    # Materialize the node list once and precompute, for every node, its ID
    # and the set of value IDs it consumes, so the edge search below does not
    # rebuild them for every (source, target) pair.
    torch_nodes = list(torch_graph.nodes())
    consumers = [
        (pytorch_id(node), {value.unique() for value in node.inputs()})
        for node in torch_nodes
    ]

    for torch_node in torch_nodes:
        node_id = pytorch_id(torch_node)
        outputs = {value.unique() for value in torch_node.outputs()}
        shape = get_shape(torch_node)

        hl_graph.add_node(Node(
            uid=node_id,
            name=None,
            op=torch_node.kind(),
            output_shape=shape,
            params=_node_params(torch_node),
        ))

        # Link this node to every node that consumes one of its outputs.
        for target_id, target_inputs in consumers:
            if not outputs.isdisjoint(target_inputs):
                hl_graph.add_edge_by_id(node_id, target_id, shape)

    return hl_graph


def _node_params(torch_node):
    """Return the attributes of a traced graph node as a plain dict."""
    # Attribute kinds whose getter returns a plain Python value.
    plain_kinds = {"f", "i", "s", "fs", "is", "ss"}

    params = {}
    for key in torch_node.attributeNames():
        kind = torch_node.kindOf(key)
        if kind in plain_kinds:
            # The getter is named after the kind; the "<kind>_" variants are
            # setters. getattr is required because "is" is a Python keyword.
            params[key] = getattr(torch_node, kind)(key)
        elif kind == "t":
            # Tensor attribute: keep a printable representation only.
            params[key] = str(torch_node.t(key))
        else:
            # Unrecognized kind (e.g. graph attributes): record the kind name.
            params[key] = f"<{kind}>"
    return params