Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge environments of nested functions #3718

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 122 additions & 61 deletions src/libexpr/eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void printValue(std::ostream & str, std::set<const Value *> & active, const Valu
break;
case tThunk:
case tApp:
case tPartialApp:
str << "<CODE>";
break;
case tLambda:
Expand Down Expand Up @@ -1275,35 +1276,28 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &
}
};

while (nrArgs > 0) {

if (vCur.isLambda()) {

ExprLambda & lambda(*vCur.lambda.fun);

auto size =
(lambda.arg.empty() ? 0 : 1) +
(lambda.hasFormals() ? lambda.formals->formals.size() : 0);
Env & env2(allocEnv(size));
env2.up = vCur.lambda.env;
auto callLambda = [&](Env * env, ExprLambda & lambda, Value * * args)
{
Env & env2(allocEnv(lambda.envSize));
env2.up = env;

Displacement displ = 0;
Displacement displ = 0;

if (!lambda.hasFormals())
env2.values[displ++] = args[0];
for (auto & arg : lambda.args) {
auto vArg = *args++;

else {
forceAttrs(*args[0], pos);
if (arg.arg != sEpsilon)
env2.values[displ++] = vArg;

if (!lambda.arg.empty())
env2.values[displ++] = args[0];
if (arg.formals) {
forceAttrs(*vArg, pos);

/* For each formal argument, get the actual argument. If
there is no matching actual argument but the formal
argument has a default, use the default. */
size_t attrsUsed = 0;
for (auto & i : lambda.formals->formals) {
auto j = args[0]->attrs->get(i.name);
for (auto & i : arg.formals->formals) {
auto j = vArg->attrs->get(i.name);
if (!j) {
if (!i.def) throwTypeError(pos, "%1% called without required argument '%2%'",
lambda, i.name);
Expand All @@ -1316,35 +1310,96 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &

/* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */
if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs->size()) {
if (!arg.formals->ellipsis && attrsUsed != vArg->attrs->size()) {
/* Nope, so show the first unexpected argument to the
user. */
for (auto & i : *args[0]->attrs)
if (lambda.formals->argNames.find(i.name) == lambda.formals->argNames.end())
for (auto & i : *vArg->attrs)
if (arg.formals->argNames.find(i.name) == arg.formals->argNames.end())
throwTypeError(pos, "%1% called with unexpected argument '%2%'", lambda, i.name);
abort(); // can't happen
}
}
}

assert(displ == lambda.envSize);

nrFunctionCalls++;
if (countCalls) incrFunctionCall(&lambda);

/* Evaluate the body. */
try {
lambda.body->eval(*this, env2, vCur);
} catch (Error & e) {
if (loggerSettings.showTrace) {
addErrorTrace(e, lambda.pos, "while evaluating %s",
(lambda.name.set()
? "'" + (string) lambda.name + "'"
: "anonymous lambda"));
addErrorTrace(e, pos, "from call site%s", "");
}
throw;
}
};

while (nrArgs > 0) {

if (vCur.isLambda()) {

nrFunctionCalls++;
if (countCalls) incrFunctionCall(&lambda);

/* Evaluate the body. */
try {
lambda.body->eval(*this, env2, vCur);
} catch (Error & e) {
if (loggerSettings.showTrace.get()) {
addErrorTrace(e, lambda.pos, "while evaluating %s",
(lambda.name.set()
? "'" + (string) lambda.name + "'"
: "anonymous lambda"));
addErrorTrace(e, pos, "from call site%s", "");
ExprLambda & lambda(*vCur.lambda.fun);

if (nrArgs < lambda.args.size()) {
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
throw;
return;
} else {
callLambda(vCur.lambda.env, lambda, args);
nrArgs -= lambda.args.size();
args += lambda.args.size();
}
}

nrArgs--;
args += 1;
else if (vCur.isPartialApp()) {
/* Figure out the number of arguments still needed. */
size_t argsDone = 0;
Value * lambda = &vCur;
while (lambda->isPartialApp()) {
argsDone++;
lambda = lambda->app.left;
}
assert(lambda->isLambda());
auto arity = lambda->lambda.fun->args.size();
auto argsLeft = arity - argsDone;

if (nrArgs < argsLeft) {
/* We still don't have enough arguments, so extend the tPartialApp chain. */
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
return;
} else {
/* We have all the arguments, so call the function
with the previous and new arguments. */

Value * vArgs[arity];
auto n = argsDone;
for (Value * arg = &vCur; arg->isPartialApp(); arg = arg->app.left)
vArgs[--n] = arg->app.right;

for (size_t i = 0; i < argsLeft; ++i)
vArgs[argsDone + i] = args[i];

nrArgs -= argsLeft;
args += argsLeft;

callLambda(lambda->lambda.env, *lambda->lambda.fun, vArgs);
}
}

else if (vCur.isPrimOp()) {
Expand Down Expand Up @@ -1458,42 +1513,48 @@ void EvalState::autoCallFunction(Bindings & args, Value & fun, Value & res)
}
}

if (!fun.isLambda() || !fun.lambda.fun->hasFormals()) {
if (!fun.isLambda()) {
res = fun;
return;
}

Value * actualArgs = allocValue();
mkAttrs(*actualArgs, std::max(static_cast<uint32_t>(fun.lambda.fun->formals->formals.size()), args.size()));
Value * actualArgs[fun.lambda.fun->args.size()];

if (fun.lambda.fun->formals->ellipsis) {
// If the formals have an ellipsis (eg the function accepts extra args) pass
// all available automatic arguments (which includes arguments specified on
// the command line via --arg/--argstr)
for (auto& v : args) {
actualArgs->attrs->push_back(v);
for (const auto & [i, arg] : enumerate(fun.lambda.fun->args)) {
if (!arg.formals) {
res = fun;
return;
}
} else {
// Otherwise, only pass the arguments that the function accepts
for (auto & i : fun.lambda.fun->formals->formals) {
Bindings::iterator j = args.find(i.name);
if (j != args.end()) {
actualArgs->attrs->push_back(*j);
} else if (!i.def) {
throwMissingArgumentError(i.pos, R"(cannot evaluate a function that has an argument without a value ('%1%')

actualArgs[i] = allocValue();
mkAttrs(*actualArgs[i], std::max(arg.formals->formals.size(), static_cast<size_t>(args.size())));

if (arg.formals->ellipsis) {
/* If the formals have an ellipsis (i.e. the function
accepts extra args), pass all available automatic
arguments. */
for (auto & v : args)
actualArgs[i]->attrs->push_back(v);
} else {
/* Otherwise, only pass the arguments that the function
accepts. */
for (auto & j : arg.formals->formals) {
if (auto attr = args.get(j.name))
actualArgs[i]->attrs->push_back(*attr);
else if (!j.def)
throwMissingArgumentError(j.pos, R"(cannot evaluate a function that has an argument without a value ('%1%')

Nix attempted to evaluate a function as a top level expression; in
this case it must have its arguments supplied either by default
values, or passed explicitly with '--arg' or '--argstr'. See
https://nixos.org/manual/nix/stable/#ss-functions.)", i.name);

https://nixos.org/manual/nix/stable/#ss-functions.)", j.name);
}
}
}

actualArgs->attrs->sort();
actualArgs[i]->attrs->sort();
}

callFunction(fun, *actualArgs, res, noPos);
callFunction(fun, fun.lambda.fun->args.size(), actualArgs, res, noPos);
}


Expand Down
9 changes: 7 additions & 2 deletions src/libexpr/flake/flake.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,13 @@ static Flake getFlake(
if (auto outputs = vInfo.attrs->get(sOutputs)) {
expectType(state, nFunction, *outputs->value, *outputs->pos);

if (outputs->value->isLambda() && outputs->value->lambda.fun->hasFormals()) {
for (auto & formal : outputs->value->lambda.fun->formals->formals) {
if (outputs->value->lambda.fun->args.size() != 1)
throw Error("the 'outputs' attribute of flake '%s' is not a unary function", lockedRef);

auto & arg = outputs->value->lambda.fun->args[0];

if (arg.formals) {
for (auto & formal : arg.formals->formals) {
if (formal.name != state.sSelf)
flake.inputs.emplace(formal.name, FlakeInput {
.ref = parseFlakeRef(formal.name)
Expand Down
81 changes: 53 additions & 28 deletions src/libexpr/nixexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,23 +124,26 @@ void ExprList::show(std::ostream & str) const
void ExprLambda::show(std::ostream & str) const
{
str << "(";
if (hasFormals()) {
str << "{ ";
bool first = true;
for (auto & i : formals->formals) {
if (first) first = false; else str << ", ";
str << i.name;
if (i.def) str << " ? " << *i.def;
}
if (formals->ellipsis) {
if (!first) str << ", ";
str << "...";
for (auto & arg : args) {
if (arg.formals) {
str << "{ ";
bool first = true;
for (auto & i : arg.formals->formals) {
if (first) first = false; else str << ", ";
str << i.name;
if (i.def) str << " ? " << *i.def;
}
if (arg.formals->ellipsis) {
if (!first) str << ", ";
str << "...";
}
str << " }";
if (!arg.arg.empty()) str << " @ ";
}
str << " }";
if (!arg.empty()) str << " @ ";
if (!arg.arg.empty()) str << arg.arg;
str << ": ";
}
if (!arg.empty()) str << arg;
str << ": " << *body << ")";
str << *body << ")";
}

void ExprCall::show(std::ostream & str) const
Expand Down Expand Up @@ -279,8 +282,7 @@ void ExprVar::bindVars(const StaticEnv & env)
if (curEnv->isWith) {
if (withLevel == -1) withLevel = level;
} else {
auto i = curEnv->find(name);
if (i != curEnv->vars.end()) {
if (auto i = curEnv->get(name)) {
fromWith = false;
this->level = level;
displ = i->second;
Expand Down Expand Up @@ -354,25 +356,48 @@ void ExprList::bindVars(const StaticEnv & env)

void ExprLambda::bindVars(const StaticEnv & env)
{
StaticEnv newEnv(
false, &env,
(hasFormals() ? formals->formals.size() : 0) +
(arg.empty() ? 0 : 1));
/* The parser adds arguments in reverse order. Let's fix that
now. */
std::reverse(args.begin(), args.end());

envSize = 0;

for (auto & arg :args) {
if (!arg.arg.empty()) envSize++;
if (arg.formals) envSize += arg.formals->formals.size();
}

StaticEnv newEnv(false, &env, envSize);

Displacement displ = 0;

if (!arg.empty()) newEnv.vars.emplace_back(arg, displ++);
for (auto & arg : args) {
if (!arg.arg.empty()) {
if (auto i = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(arg.arg)))
i->second = displ++;
else
newEnv.vars.emplace_back(arg.arg, displ++);
}

if (hasFormals()) {
for (auto & i : formals->formals)
newEnv.vars.emplace_back(i.name, displ++);
if (arg.formals) {
for (auto & i : arg.formals->formals) {
if (auto j = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(i.name)))
j->second = displ++;
else
newEnv.vars.emplace_back(i.name, displ++);
}

newEnv.sort();
newEnv.sort();

for (auto & i : formals->formals)
if (i.def) i.def->bindVars(newEnv);
for (auto & i : arg.formals->formals)
if (i.def) i.def->bindVars(newEnv);
}
}

assert(displ == envSize);

newEnv.sort();

body->bindVars(newEnv);
}

Expand Down
Loading