Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ir: add syntactic support for vscale, vscale_range #1121

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ir/attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ ostream& operator<<(ostream &os, const FnAttrs &attr) {
os << ", " << attr.allocsize_1;
os << ')';
}
if (attr.vscaleRange) {
auto [low, high] = *attr.vscaleRange;
os << " vscale_range(" << low << ", " << high << ')';
}

attr.fp_denormal.print(os);
if (attr.fp_denormal32)
Expand Down
2 changes: 2 additions & 0 deletions ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class FnAttrs final {
AllocSize = 1 << 12, ZeroExt = 1<<13,
SignExt = 1<<14, NoFPClass = 1<<15, Asm = 1<<16 };

std::optional<std::pair<uint16_t, uint16_t>> vscaleRange;

FnAttrs(unsigned bits = None) : bits(bits) {}

bool has(Attribute a) const { return (bits & a) != 0; }
Expand Down
22 changes: 11 additions & 11 deletions ir/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ StateValue IntConst::toSMT(State &s) const {
return { expr::mkInt(get<string>(val).c_str(), bits()), true };
}

expr IntConst::getTypeConstraints() const {
expr IntConst::getTypeConstraints(const Function &f) const {
unsigned min_bits = 0;
if (auto v = get_if<int64_t>(&val))
min_bits = (*v >= 0 ? 63 : 64) - num_sign_bits(*v);

return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
getType().sizeVar().uge(min_bits);
}
Expand Down Expand Up @@ -86,8 +86,8 @@ FloatConst::FloatConst(Type &type, string val, bool bit_value)
: Constant(type, bit_value ? int_to_readable_float(type, val) : val),
val(std::move(val)), bit_value(bit_value) {}

expr FloatConst::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr FloatConst::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
getType().enforceFloatType();
}

Expand All @@ -108,8 +108,8 @@ StateValue ConstantInput::toSMT(State &s) const {
return { expr::mkVar(getName().c_str(), type), true };
}

expr ConstantInput::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr ConstantInput::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
(getType().enforceIntType() || getType().enforceFloatType());
}

Expand Down Expand Up @@ -157,8 +157,8 @@ StateValue ConstantBinOp::toSMT(State &s) const {
return { std::move(val), ap && bp };
}

expr ConstantBinOp::getTypeConstraints() const {
return Value::getTypeConstraints() &&
expr ConstantBinOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
getType() == lhs.getType() &&
getType() == rhs.getType();
Expand Down Expand Up @@ -210,10 +210,10 @@ StateValue ConstantFn::toSMT(State &s) const {
return { std::move(r), true };
}

expr ConstantFn::getTypeConstraints() const {
expr r = Value::getTypeConstraints();
expr ConstantFn::getTypeConstraints(const Function &f) const {
expr r = Value::getTypeConstraints(f);
for (auto a : args) {
r &= a->getTypeConstraints();
r &= a->getTypeConstraints(f);
}

Type &ty = getType();
Expand Down
10 changes: 5 additions & 5 deletions ir/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class IntConst final : public Constant {
IntConst(Type &type, int64_t val);
IntConst(Type &type, std::string &&val);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
auto getInt() const { return std::get_if<int64_t>(&val); }
};

Expand All @@ -38,7 +38,7 @@ class FloatConst final : public Constant {
FloatConst(Type &type, std::string val, bool bit_value);

StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -47,7 +47,7 @@ class ConstantInput final : public Constant {
ConstantInput(Type &type, std::string &&name)
: Constant(type, std::move(name)) {}
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -62,7 +62,7 @@ class ConstantBinOp final : public Constant {
public:
ConstantBinOp(Type &type, Constant &lhs, Constant &rhs, Op op);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};


Expand All @@ -73,7 +73,7 @@ class ConstantFn final : public Constant {
public:
ConstantFn(Type &type, std::string_view name, std::vector<Value*> &&args);
StateValue toSMT(State &s) const override;
smt::expr getTypeConstraints() const override;
smt::expr getTypeConstraints(const Function &f) const override;
};

struct ConstantFnException {
Expand Down
2 changes: 1 addition & 1 deletion ir/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ expr Function::getTypeConstraints() const {
}
for (auto &l : { getConstants(), getInputs(), getUndefs() }) {
for (auto &v : l) {
t &= v.getTypeConstraints();
t &= v.getTypeConstraints(*this);
}
}
return t;
Expand Down
56 changes: 28 additions & 28 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ uint64_t getGlobalVarSize(const IR::Value *V) {

namespace IR {

expr Instr::getTypeConstraints() const {
expr Instr::getTypeConstraints(const Function &f) const {
UNREACHABLE();
return {};
}
Expand Down Expand Up @@ -596,7 +596,7 @@ expr BinOp::getTypeConstraints(const Function &f) const {
getType() == rhs->getType();
break;
}
return Value::getTypeConstraints() && std::move(instrconstr);
return Value::getTypeConstraints(f) && std::move(instrconstr);
}

unique_ptr<Instr> BinOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -958,7 +958,7 @@ StateValue FpBinOp::toSMT(State &s) const {
}

expr FpBinOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceFloatOrVectorType() &&
getType() == lhs->getType() &&
getType() == rhs->getType();
Expand Down Expand Up @@ -1086,7 +1086,7 @@ expr UnaryOp::getTypeConstraints(const Function &f) const {
break;
}

return Value::getTypeConstraints() && std::move(instrconstr);
return Value::getTypeConstraints(f) && std::move(instrconstr);
}

static Value* dup_aggregate(Function &f, Value *val) {
Expand Down Expand Up @@ -1213,7 +1213,7 @@ StateValue FpUnaryOp::toSMT(State &s) const {
}

expr FpUnaryOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType() &&
getType().enforceFloatOrVectorType();
}
Expand Down Expand Up @@ -1286,7 +1286,7 @@ StateValue UnaryReductionOp::toSMT(State &s) const {
}

expr UnaryReductionOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntType() &&
val->getType().enforceVectorType(
[this](auto &scalar) { return scalar == getType(); });
Expand Down Expand Up @@ -1405,7 +1405,7 @@ expr TernaryOp::getTypeConstraints(const Function &f) const {
getType().enforceIntOrVectorType();
break;
}
return Value::getTypeConstraints() && instrconstr;
return Value::getTypeConstraints(f) && instrconstr;
}

unique_ptr<Instr> TernaryOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -1486,7 +1486,7 @@ StateValue FpTernaryOp::toSMT(State &s) const {
}

expr FpTernaryOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == a->getType() &&
getType() == b->getType() &&
getType() == c->getType() &&
Expand Down Expand Up @@ -1557,7 +1557,7 @@ StateValue TestOp::toSMT(State &s) const {
}

expr TestOp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
lhs->getType().enforceFloatOrVectorType() &&
rhs->getType().enforceIntType(32) &&
getType().enforceIntOrVectorType(1) &&
Expand Down Expand Up @@ -1721,7 +1721,7 @@ expr ConversionOp::getTypeConstraints(const Function &f) const {
break;
}

c &= Value::getTypeConstraints();
c &= Value::getTypeConstraints(f);
if (op != BitCast)
c &= getType().enforceVectorTypeEquiv(val->getType());
return c;
Expand Down Expand Up @@ -1965,7 +1965,7 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const {
val->getType().scalarSize().ugt(getType().scalarSize());
break;
}
return Value::getTypeConstraints() && c;
return Value::getTypeConstraints(f) && c;
}

unique_ptr<Instr> FpConversionOp::dup(Function &f, const string &suffix) const {
Expand Down Expand Up @@ -2027,7 +2027,7 @@ StateValue Select::toSMT(State &s) const {
}

expr Select::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
cond->getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeIff(cond->getType()) &&
(fmath.isNone() ? expr(true) : getType().enforceFloatOrVectorType()) &&
Expand Down Expand Up @@ -2080,7 +2080,7 @@ StateValue ExtractValue::toSMT(State &s) const {
}

expr ExtractValue::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
val->getType().enforceAggregateType();

Type *type = &val->getType();
Expand Down Expand Up @@ -2172,7 +2172,7 @@ StateValue InsertValue::toSMT(State &s) const {
}

expr InsertValue::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
val->getType().enforceAggregateType() &&
val->getType() == getType();

Expand Down Expand Up @@ -2646,7 +2646,7 @@ StateValue FnCall::toSMT(State &s) const {

expr FnCall::getTypeConstraints(const Function &f) const {
// TODO : also need to name each arg type smt var uniquely
expr ret = Value::getTypeConstraints();
expr ret = Value::getTypeConstraints(f);
if (fnptr)
ret &= fnptr->getType().enforcePtrType();
return ret;
Expand Down Expand Up @@ -2809,7 +2809,7 @@ StateValue ICmp::toSMT(State &s) const {
}

expr ICmp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeEquiv(a->getType()) &&
a->getType().enforceIntOrPtrOrVectorType() &&
Expand Down Expand Up @@ -2908,7 +2908,7 @@ StateValue FCmp::toSMT(State &s) const {
}

expr FCmp::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceIntOrVectorType(1) &&
getType().enforceVectorTypeEquiv(a->getType()) &&
a->getType().enforceFloatOrVectorType() &&
Expand Down Expand Up @@ -2968,7 +2968,7 @@ StateValue Freeze::toSMT(State &s) const {
}

expr Freeze::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType();
}

Expand Down Expand Up @@ -3080,7 +3080,7 @@ StateValue Phi::toSMT(State &s) const {
}

expr Phi::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints();
auto c = Value::getTypeConstraints(f);
for (auto &[val, bb] : values) {
c &= val->getType() == getType();
}
Expand Down Expand Up @@ -3324,7 +3324,7 @@ StateValue Return::toSMT(State &s) const {
}

expr Return::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == val->getType() &&
f.getType() == getType();
}
Expand Down Expand Up @@ -3711,7 +3711,7 @@ StateValue Alloc::toSMT(State &s) const {
}

expr Alloc::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforcePtrType() &&
size->getType().enforceIntType();
}
Expand Down Expand Up @@ -3967,7 +3967,7 @@ StateValue GEP::toSMT(State &s) const {
}

expr GEP::getTypeConstraints(const Function &f) const {
auto c = Value::getTypeConstraints() &&
auto c = Value::getTypeConstraints(f) &&
getType().enforceVectorTypeIff(ptr->getType()) &&
getType().enforcePtrOrVectorType();
for (auto &[sz, idx] : idxs) {
Expand Down Expand Up @@ -4052,7 +4052,7 @@ StateValue PtrMask::toSMT(State &s) const {
}

expr PtrMask::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrOrVectorType() &&
getType() == ptr->getType() &&
mask->getType().enforceIntOrVectorType() &&
Expand Down Expand Up @@ -4101,7 +4101,7 @@ StateValue Load::toSMT(State &s) const {
}

expr Load::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrType();
}

Expand Down Expand Up @@ -4583,7 +4583,7 @@ StateValue Strlen::toSMT(State &s) const {
}

expr Strlen::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
ptr->getType().enforcePtrType() &&
getType().enforceIntType();
}
Expand Down Expand Up @@ -4886,7 +4886,7 @@ StateValue ExtractElement::toSMT(State &s) const {
}

expr ExtractElement::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
v->getType().enforceVectorType([&](auto &ty)
{ return ty == getType(); }) &&
idx->getType().enforceIntType();
Expand Down Expand Up @@ -4929,7 +4929,7 @@ StateValue InsertElement::toSMT(State &s) const {
}

expr InsertElement::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType() == v->getType() &&
v->getType().enforceVectorType([&](auto &ty)
{ return ty == e->getType(); }) &&
Expand Down Expand Up @@ -4984,7 +4984,7 @@ StateValue ShuffleVector::toSMT(State &s) const {
}

expr ShuffleVector::getTypeConstraints(const Function &f) const {
return Value::getTypeConstraints() &&
return Value::getTypeConstraints(f) &&
getType().enforceVectorTypeSameChildTy(v1->getType()) &&
getType().getAsAggregateType()->numElements() == mask.size() &&
v1->getType().enforceVectorType() &&
Expand Down
3 changes: 1 addition & 2 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class Instr : public Value {
virtual bool propagatesPoison() const = 0;
virtual bool hasSideEffects() const = 0;
virtual bool isTerminator() const;
smt::expr getTypeConstraints() const override;
virtual smt::expr getTypeConstraints(const Function &f) const = 0;
smt::expr getTypeConstraints(const Function &f) const override;
virtual std::unique_ptr<Instr> dup(Function &f,
const std::string &suffix) const = 0;
};
Expand Down
Loading
Loading