From f17e4063ffa0dafeac8f577ce92b1b48b0aa5f9f Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 30 Nov 2024 11:46:59 -0800 Subject: [PATCH] Add torch.export test file (#1211) --- source/python.js | 39 +++++++++++++++++++++++++++++++++++---- source/pytorch.js | 10 ++++++---- test/models.json | 7 +++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/source/python.js b/source/python.js index 86495b944b..042af0c494 100644 --- a/source/python.js +++ b/source/python.js @@ -9249,7 +9249,7 @@ python.Execution = class { this.kind = kind; this.arg = arg; this.target = target; - this.persistent = persistent; + this.persistent = persistent || null; } }); torch.export.graph_signature.OutputKind = { @@ -9616,6 +9616,24 @@ python.Execution = class { this.tensor_constant_name = obj.tensor_constant_name; } }); + this.registerType('torch._export.serde.schema.InputToConstantInputSpec', class { + constructor(obj) { + this.name = obj.name; + this.value = new torch._export.serde.schema.ConstantValue(obj.value); + } + }); + this.registerType('torch._export.serde.schema.ConstantValue', class extends torch._export.serde.union._Union { + constructor(obj) { + super(obj); + if (this.type === 'as_int' || this.type === 'as_float' || this.type === 'as_bool' || this.type === 'as_string' || this.type === 'as_strings') { + // continue + } else if (this.type === 'as_none') { + this.as_none = null; + } else { + throw new python.Error(`Unsupported constant value type '${this.type}'.`); + } + } + }); this.registerType('torch._export.serde.schema.InputSpec', class extends torch._export.serde.union._Union { constructor(obj) { super(obj); @@ -9627,14 +9645,14 @@ python.Execution = class { this.buffer = new torch._export.serde.schema.InputToBufferSpec(this.buffer); } else if (this.type === 'tensor_constant') { this.tensor_constant = new torch._export.serde.schema.InputToTensorConstantSpec(this.tensor_constant); + } else if (this.type === 'constant_input') { + this.constant_input = new torch._export.serde.schema.InputToConstantInputSpec(this.constant_input); } else { throw new python.Error(`Unsupported input spec type '${this.type}'.`); } /* - tensor_constant: InputToTensorConstantSpec custom_obj: InputToCustomObjSpec token: InputTokenSpec - constant_input: ConstantInputSpec */ } }); @@ -9915,6 +9933,20 @@ python.Execution = class { } throw new python.Error(`Unknown input spec ${i}`); } + deserialize_constant_input(inp) { + if (inp.type === 'as_int') { + return inp.as_int; + } else if (inp.type === 'as_float') { + return inp.as_float; + } else if (inp.type === 'as_string') { + return inp.as_string; + } else if (inp.type === 'as_bool') { + return inp.as_bool; + } else if (inp.type === 'as_none') { + return null; + } + throw new python.Error(`Unhandled constant argument ${inp} to deserialize.`); + } deserialize_output_spec(o) { if (o.type === 'user_output') { return new torch.export.graph_signature.OutputSpec( @@ -10045,7 +10077,6 @@ python.Execution = class { value.name, name=value.name, )*/ - throw new Error(); } else if (typ_ === 'as_device') { return this.deserialize_device(inp.as_device); } else if (typ_ === 'as_int') { diff --git a/source/pytorch.js b/source/pytorch.js index cddf863b99..8eab62fe9b 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -296,10 +296,12 @@ pytorch.Graph = class { this.nodes.push(node); } for (const input_spec of exported_program.graph_signature.user_inputs()) { - const node = nodes.get(input_spec); - const value = values.map(node); - const argument = new pytorch.Argument(input_spec, [value]); - this.inputs.push(argument); + if (nodes.has(input_spec)) { + const node = nodes.get(input_spec); + const value = values.map(node); + const argument = new pytorch.Argument(input_spec, [value]); + this.inputs.push(argument); + } } /* for (const output_spec of exported_program.graph_signature.user_outputs()) { diff --git a/test/models.json b/test/models.json index f11266bf63..1fd7133912 100644 --- a/test/models.json +++ b/test/models.json @@ -5384,6 +5384,13 @@ "format": "PyTorch v0.1.10", "link": "https://github.com/babajide07/Redundant-Feature-Pruning-Pytorch-Implementation" }, + { + "type": "pytorch", + "target": "constant_input.pt2", + "source": "https://github.com/user-attachments/files/17966651/constant_input.pt2.zip[constant_input.pt2]", + "format": "PyTorch Export v7.3", + "link": "https://github.com/lutzroeder/netron/issues/1211" + }, { "type": "pytorch", "target": "cpu_jit.pt",