Skip to content

Commit

Permalink
Add torch.export test file (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 30, 2024
1 parent fbc1344 commit f17e406
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
39 changes: 35 additions & 4 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -9249,7 +9249,7 @@ python.Execution = class {
this.kind = kind;
this.arg = arg;
this.target = target;
this.persistent = persistent;
this.persistent = persistent || null;
}
});
torch.export.graph_signature.OutputKind = {
Expand Down Expand Up @@ -9616,6 +9616,24 @@ python.Execution = class {
this.tensor_constant_name = obj.tensor_constant_name;
}
});
this.registerType('torch._export.serde.schema.InputToConstantInputSpec', class {
constructor(obj) {
this.name = obj.name;
this.value = new torch._export.serde.schema.ConstantValue(obj.value);
}
});
this.registerType('torch._export.serde.schema.ConstantValue', class extends torch._export.serde.union._Union {
constructor(obj) {
super(obj);
if (this.type === 'as_int' || this.type === 'as_float' || this.type === 'as_bool' || this.type === 'as_string' || this.type === 'as_strings') {
// continue
} else if (this.type === 'as_none') {
this.as_none = null;
} else {
throw new python.Error(`Unsupported constant value type '${this.type}'.`);
}
}
});
this.registerType('torch._export.serde.schema.InputSpec', class extends torch._export.serde.union._Union {
constructor(obj) {
super(obj);
Expand All @@ -9627,14 +9645,14 @@ python.Execution = class {
this.buffer = new torch._export.serde.schema.InputToBufferSpec(this.buffer);
} else if (this.type === 'tensor_constant') {
this.tensor_constant = new torch._export.serde.schema.InputToTensorConstantSpec(this.tensor_constant);
} else if (this.type === 'constant_input') {
this.constant_input = new torch._export.serde.schema.InputToConstantInputSpec(this.constant_input);
} else {
throw new python.Error(`Unsupported input spec type '${this.type}'.`);
}
/*
tensor_constant: InputToTensorConstantSpec
custom_obj: InputToCustomObjSpec
token: InputTokenSpec
constant_input: ConstantInputSpec
*/
}
});
Expand Down Expand Up @@ -9915,6 +9933,20 @@ python.Execution = class {
}
throw new python.Error(`Unknown input spec ${i}`);
}
deserialize_constant_input(inp) {
if (inp.type === 'as_int') {
return inp.as_int;
} else if (inp.type === 'as_float') {
return inp.as_float;
} else if (inp.type === 'as_string') {
return inp.as_string;
} else if (inp.type === 'as_bool') {
return inp.as_bool;
} else if (inp.type === 'as_none') {
return null;
}
throw new python.Error(`Unhandled constant argument ${inp} to deserialize.`);
}
deserialize_output_spec(o) {
if (o.type === 'user_output') {
return new torch.export.graph_signature.OutputSpec(
Expand Down Expand Up @@ -10045,7 +10077,6 @@ python.Execution = class {
value.name,
name=value.name,
)*/
throw new Error();
} else if (typ_ === 'as_device') {
return this.deserialize_device(inp.as_device);
} else if (typ_ === 'as_int') {
Expand Down
10 changes: 6 additions & 4 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,12 @@ pytorch.Graph = class {
this.nodes.push(node);
}
for (const input_spec of exported_program.graph_signature.user_inputs()) {
const node = nodes.get(input_spec);
const value = values.map(node);
const argument = new pytorch.Argument(input_spec, [value]);
this.inputs.push(argument);
if (nodes.has(input_spec)) {
const node = nodes.get(input_spec);
const value = values.map(node);
const argument = new pytorch.Argument(input_spec, [value]);
this.inputs.push(argument);
}
}
/*
for (const output_spec of exported_program.graph_signature.user_outputs()) {
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5384,6 +5384,13 @@
"format": "PyTorch v0.1.10",
"link": "https://github.com/babajide07/Redundant-Feature-Pruning-Pytorch-Implementation"
},
{
"type": "pytorch",
"target": "constant_input.pt2",
"source": "https://github.com/user-attachments/files/17966651/constant_input.pt2.zip[constant_input.pt2]",
"format": "PyTorch Export v7.3",
"link": "https://github.com/lutzroeder/netron/issues/1211"
},
{
"type": "pytorch",
"target": "cpu_jit.pt",
Expand Down

0 comments on commit f17e406

Please sign in to comment.