Add support for nested multiversioned functions #8

Closed
calebzulawski opened this issue Oct 7, 2019 · 26 comments

@calebzulawski
Owner

In this example:

#[target_clones("[x86|x86_64]+avx")]
fn foo() {  /* snip */ }

#[target_clones("[x86|x86_64]+avx")]
fn bar() { foo(); }

foo should be statically dispatched when invoked in bar, since the CPU features have already been established when dispatching bar. It would also be nice if this even worked when functions have mismatched feature sets (x86+sse+avx should be able to statically dispatch x86+sse functions).
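To make the intended behavior concrete, here is a minimal hand-written sketch. It assumes the clones are written out manually; the names `foo_default`, `foo_avx`, `bar_default`, `bar_avx`, and the `CHECKS` counter are hypothetical and are not the crate's generated output. Only `bar` performs runtime detection, and its AVX clone calls `foo`'s AVX clone directly. On non-x86 targets the AVX clones are compiled out and the default path is used.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Counts runtime feature checks so the dispatch points are visible (illustration only).
static CHECKS: AtomicUsize = AtomicUsize::new(0);

fn cpu_has_avx() -> bool {
    CHECKS.fetch_add(1, Ordering::Relaxed);
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    let found = is_x86_feature_detected!("avx");
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    let found = false;
    found
}

// Hypothetical hand-written clones of `foo` and `bar`.
fn foo_default(x: u64) -> u64 {
    x * x
}

fn bar_default(x: u64) -> u64 {
    // Static call: no second feature check.
    foo_default(x) + 1
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn foo_avx(x: u64) -> u64 {
    x * x
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn bar_avx(x: u64) -> u64 {
    // AVX was established when `bar` was dispatched, so call foo's AVX clone directly.
    let y = unsafe { foo_avx(x) };
    y + 1
}

// The only dynamic dispatcher that runs: one feature check per call to `bar`.
fn bar(x: u64) -> u64 {
    if cpu_has_avx() {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        return unsafe { bar_avx(x) };
    }
    bar_default(x)
}
```

Each call to `bar` increments `CHECKS` exactly once, regardless of the nested call to `foo`.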

@calebzulawski
Owner Author

calebzulawski commented Oct 7, 2019

An idea that I've been toying with is exposing the dispatcher to the module with a signature like fn (&'static [&'static str]) -> dispatched_fn_type. The input slice of strings would contain available features.

This would work in tandem with a new attribute static_dispatch, which would embed a nested function that calls the function statically rather than dynamically. Example usage:

#[target_clones("[x86|x86_64]+avx")]
fn foo() {  /* snip */ }

#[target_clones("[x86|x86_64]+avx")]
#[static_dispatch(foo)]
fn bar() { foo(); }

The generated code for bar would look something like this:

fn __dispatch_bar(features: &'static [&'static str]) -> fn() {
    fn __clone_0() {
        /* embedded via static_dispatch */
        fn foo() {
            __dispatch_foo(&["avx"])();
        }
        /* the original definition of bar */
        foo();
    }
    fn __clone_1() {
        /* embedded via static_dispatch */
        fn foo() {
            __dispatch_foo(&[])();
        }
        /* the original definition of bar */
        foo();
    }
    /* statically dispatch clones here */
}

fn bar() {
    /* dynamically dispatch here via __dispatch_bar */
}

It appears that const propagation optimization is good enough to turn the static dispatching into inlined functions (!) based on this test: https://rust.godbolt.org/z/gg-XQu
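A runnable approximation of the generated code above, under the assumption that clones are plain functions returning a tag string so the chosen path is observable. `CloneFn` and the clone names are hypothetical, and no `#[target_feature]` is involved. The nested clones of `bar` pass a constant feature list to `__dispatch_foo`, which is what gives the optimizer a chance to resolve the call at compile time.

```rust
// Hypothetical names modelled on the `__dispatch_*` proposal; not the crate's API.
type CloneFn = fn() -> &'static str;

fn __dispatch_foo(features: &'static [&'static str]) -> CloneFn {
    fn avx() -> &'static str {
        "foo: avx clone"
    }
    fn fallback() -> &'static str {
        "foo: default clone"
    }
    if features.contains(&"avx") { avx as CloneFn } else { fallback }
}

fn __dispatch_bar(features: &'static [&'static str]) -> CloneFn {
    fn avx() -> &'static str {
        // Embedded via static_dispatch: the feature list is a constant here.
        __dispatch_foo(&["avx"])()
    }
    fn fallback() -> &'static str {
        __dispatch_foo(&[])()
    }
    if features.contains(&"avx") { avx as CloneFn } else { fallback }
}
```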

@TethysSvensson, any opinion on this implementation? My only qualm with it is that it exposes an implementation detail (__dispatch_foo, etc.), but I'm fairly certain that's necessary if we're not wrapping everything in a single function-like macro. I think I'm okay with that as long as we give it a name that's unlikely to cause problems, and limit its visibility to pub(crate) at most.

@TethysSvensson
Contributor

TethysSvensson commented Oct 8, 2019

I'm not completely convinced that having x86+sse+avx dispatch to x86+sse is a good idea.

What if foo comes in two varieties, x86+feature1 and x86+feature2, and we try to dispatch to it from a function with x86+feature1+feature2? What would be the expected behavior in that case?
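One way to make the expected behavior precise is to pick, among the callee's clones whose features are all enabled in the caller, the one that is not a strict subset of any other, and to report an error when more than one such clone remains. The sketch below assumes feature sets are plain string slices; `select_clone` and `is_strict_superset` are hypothetical helpers, not part of the crate.

```rust
/// Hypothetical helper: pick the clone a caller with `caller` features should use.
/// Returns Ok(None) for "use the default clone" and Err if several clones tie.
fn select_clone<'a>(caller: &[&'a str], clones: &[&[&'a str]]) -> Result<Option<usize>, String> {
    // A clone is usable when the caller enables every feature it requires.
    let usable: Vec<usize> = (0..clones.len())
        .filter(|&i| clones[i].iter().all(|f| caller.contains(f)))
        .collect();
    // Drop any usable clone whose features are a strict subset of another usable clone.
    let best: Vec<usize> = usable
        .iter()
        .copied()
        .filter(|&i| !usable.iter().any(|&j| is_strict_superset(clones[j], clones[i])))
        .collect();
    match best.len() {
        0 => Ok(None),
        1 => Ok(Some(best[0])),
        _ => Err(format!("ambiguous static dispatch: clones {:?} all qualify", best)),
    }
}

/// True if `a` contains every feature of `b` plus at least one more.
fn is_strict_superset<'a>(a: &[&'a str], b: &[&'a str]) -> bool {
    b.iter().all(|f| a.contains(f)) && a.iter().any(|f| !b.contains(f))
}
```

With clones `[feature1]` and `[feature2]` and a caller enabling both, two clones qualify and the helper returns an error, which a proc macro could surface as a compile-time error.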

If we can live with having a compile-time error in this case, then we can do something where this:

#[target_clones("x86_64+avx")]
fn foo() { /* snip */ }

#[target_clones("x86_64+avx")]
#[static_dispatch(foo)]
pub fn bar() { foo(); }

Turns into this:

fn foo() { /* same dispatcher logic as currently */ }
mod foo {
    use super::*;
    #[target_feature(enable = "avx")]
    pub(super) unsafe fn x86_64_avx() {
        fn foo() { unsafe { foo::x86_64_avx() } }
        fn safe_wrapper() { /* snip */ }
        safe_wrapper();
    }

    pub(super) unsafe fn default_impl() {
        fn foo() { unsafe { foo::default_impl() } }
        fn safe_wrapper() { /* snip */ }
        safe_wrapper();
    }
}

pub fn bar() { /* same dispatcher logic as currently */ }
pub mod bar {
    use super::*;
    #[target_feature(enable = "avx")]
    pub unsafe fn x86_64_avx() {
        fn foo() { unsafe { foo::x86_64_avx() } }
        fn bar() { unsafe { bar::x86_64_avx() } }
        fn safe_wrapper() { foo(); }
        safe_wrapper();
    }

    pub unsafe fn default_impl() {
        fn foo() { unsafe { foo::default_impl() } }
        fn bar() { unsafe { bar::default_impl() } }
        fn safe_wrapper() { foo(); }
        safe_wrapper();
    }
}

This is very close to what we are already doing, except that the paths and visibility have been changed.

@calebzulawski
Owner Author

That was my original idea too, it's definitely very similar to what we have now. You make a good point about a feature mismatch where multiple choices are valid. It sounds like recommended usage should be to use the same feature sets on every interacting function in the crate.

I think part of the reason I wanted something more complicated was because the list of clones could get very long, but that's definitely a different issue (I've opened #10 to try to deal with that).

One issue that I'm just thinking of: what if foo is an unsafe fn but bar is just fn? I'm not sure we should be embedding fn foo() { unsafe { foo::default_impl() } }, but rather unsafe fn foo() { foo::default_impl() }. Maybe in this case the attribute should be #[static_dispatch(unsafe foo)]?

@TethysSvensson
Contributor

TethysSvensson commented Oct 9, 2019

How about this then?

  • pub: The functions inside the module all get the same pub specifier as the original function. If the original function does not have one, they get pub(super). The module gets the same pub specifier as the function.
  • unsafe: The functions with target_feature are all unsafe (by requirement). The default_impl has the same safe/unsafe status as the original function.
  • imports: Inside the target_feature functions, we create wrapper functions that make the callees safe as above. Inside the default_impl, however, we just do use super::foo::default_impl, to get a function with the correct safe/unsafe semantics.

As far as I can tell, this solves the problem. If a safe function bar tries to call an unsafe function foo using #[static_dispatch(foo)], then bar::default_impl will import foo::default_impl and try to call it, giving the desired compilation error.
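A compact hand expansion that follows these rules, assuming a single AVX clone and a dispatcher that always falls back to `default_impl` (real detection is elided). The module and function names mirror the thread; the bodies are placeholders.

```rust
// Hypothetical expansion following the pub/unsafe/imports rules above.
pub fn foo(x: i32) -> i32 {
    // Stand-in dispatcher: a real one would check for AVX first.
    foo::default_impl(x)
}

pub mod foo {
    // Same safety as the original `foo` (safe here).
    pub fn default_impl(x: i32) -> i32 {
        x * 2
    }

    // Feature-specific clones are always `unsafe`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx")]
    pub unsafe fn x86_64_avx(x: i32) -> i32 {
        x * 2
    }
}

pub fn bar(x: i32) -> i32 {
    bar::default_impl(x)
}

pub mod bar {
    pub fn default_impl(x: i32) -> i32 {
        // Plain import: if `foo::default_impl` were an `unsafe fn`, this safe
        // call would fail to compile, which is the desired error.
        use super::foo::default_impl as foo;
        foo(x) + 1
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx")]
    pub unsafe fn x86_64_avx(x: i32) -> i32 {
        // Safe wrapper for the body; only reachable from inside this AVX clone.
        fn foo(x: i32) -> i32 {
            unsafe { super::foo::x86_64_avx(x) }
        }
        foo(x) + 1
    }
}
```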

@calebzulawski
Owner Author

calebzulawski commented Oct 9, 2019

I agree with your comments on pub and unsafe, and came to the same conclusion. One unexpected problem is that since we don't have the type of the nested function, we aren't able to create an inner function that statically dispatches it:

mod foo {
    use super::*;
    pub fn default() {
        fn bar(/* how do we determine these args? */) -> /* or the return type? */ {
            bar::default(args)
        }    
    }
}

Since we're in a proc macro, I think the way to go will be to parse the body and replace all instances of bar with bar::default (or whichever implementation). Something that might be even a little cleaner is to apply the attribute directly to the function call:

#[target_clones("x86_64+avx")]
fn foo() {
    let baz = #[static_dispatch] bar();
}

@TethysSvensson
Contributor

I am not sure that solution actually solves the problem, since the target_feature functions must be marked as unsafe and we still do not know the argument count and types inside the other functions.

I think this could be solved with a sufficient number of wrappers, but it becomes quite cumbersome at some point.

@TethysSvensson
Contributor

We could for instance do it something like this:

#[target_clones("x86_64+avx")]
fn foo(x: i32, y: i32) -> i32 {
    if x <= 0 {
        y
    } else {
        1 + foo(x - 1, y)
    }
}

#[target_clones("x86_64+avx")]
#[static_dispatch(foo)]
fn bar(x: i32) -> i32 {
    foo(x, x)
}

Becomes:

fn foo(x: i32, y: i32) -> i32 { /* dispatcher logic */ }
mod foo {
    type FnType = fn(i32, i32) -> i32;

    #[inline(always)]
    pub(super) unsafe fn avx() -> FnType {
        #[target_feature(enable = "avx")]
        unsafe fn avx(x: i32, y: i32) -> i32 {
            #[inline(always)]
            fn safe(x: i32, y: i32) -> i32 {
                if x <= 0 {
                    y
                } else {
                    1 + (unsafe { super::foo::avx() })(x - 1, y)
                }
            }
            safe(x, y)
        }

        #[inline(always)]
        fn foo(x: i32, y: i32) -> i32 {
            unsafe { avx(x, y) }
        }

        foo
    }

    #[inline(always)]
    pub(super) unsafe fn default_impl() -> FnType {
        unsafe fn default_impl(x: i32, y: i32) -> i32 {
            #[inline(always)]
            fn safe(x: i32, y: i32) -> i32 {
                if x <= 0 {
                    y
                } else {
                    1 + (unsafe { super::foo::default_impl() })(x - 1, y)
                }
            }
            safe(x, y)
        }

        #[inline(always)]
        fn foo(x: i32, y: i32) -> i32 {
            unsafe { default_impl(x, y) }
        }

        foo
    }
}


fn bar(x: i32) -> i32 { /* dispatcher logic */ }
mod bar {
    type FnType = fn(i32) -> i32;

    #[inline(always)]
    pub(super) unsafe fn avx() -> FnType {
        #[target_feature(enable = "avx")]
        unsafe fn avx(x: i32) -> i32 {
            #[inline(always)]
            fn safe(x: i32) -> i32 {
                (unsafe { super::foo::avx() })(x, x)
            }
            safe(x)
        }

        #[inline(always)]
        fn bar(x: i32) -> i32 {
            unsafe { avx(x) }
        }

        bar
    }

    #[inline(always)]
    pub(super) unsafe fn default_impl() -> FnType {
        unsafe fn default_impl(x: i32) -> i32 {
            #[inline(always)]
            fn safe(x: i32) -> i32 {
                (unsafe { super::foo::default_impl() })(x, x)
            }
            safe(x)
        }

        #[inline(always)]
        fn bar(x: i32) -> i32 {
            unsafe { default_impl(x) }
        }

        bar
    }
}

The idea here is to go through the body and replace every instance of foo with (unsafe { super::foo::relevant_impl() }).

This is really ugly IMO. However I have thought about it for a while and have not yet come up with anything better that achieves all of the following properties:

  • Does not sacrifice the performance of the dispatcher (i.e. just a single indirect jump on the fast path)
  • Does not mess with safe/unsafe dynamics, i.e. the code should compile with target_features if and only if it compiles without.
  • Does not expose any way to call the feature-specific code without either writing unsafe or going through the dispatcher.
  • Does not require a lot of boilerplate for the user.

If you think this is the right approach, I think I have time for implementing it. That is, once we agree on what we want. On the other hand, if you are looking for something to hack on, don't let me stop you. 😉

@calebzulawski
Owner Author

calebzulawski commented Oct 9, 2019

I agree that replacing the invoked function at the call site is probably the only easy (for users, not for us) and safe way to do this. A few comments:

  • Since we don't know the safety of the call site, I think the replacement should actually be:
    {
        #[allow(unused_unsafe)]
        unsafe { foo(/* args */) }
    }
  • I think it might be a better idea to apply the attribute at the call site instead of function wide. For example, this could go poorly if foo is replaced globally:
    #[target_clones("x86_64+avx")]
    fn foo(x: i32) -> i32 { x * x }
    
    #[target_clones("x86_64+avx")]
    fn bar(x: i32) -> i32 {
        let foo = |x| foo(x + 1);
        foo(x)
    }
    Syn's AST is pretty huge, but I think we can crawl it, look for all instances of syn::ExprPath, and parse its attributes for #[static_dispatch] and make replacements as necessary.
  • We need to remember to keep leading components of paths when making the replacement, e.g. crate::foo::bar should become crate::foo::bar::relevant_impl()
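Syn's real types are much larger, so the following is only a toy model of the proposed call-site rewrite. The `Expr` enum, `rewrite`, and `path` are hypothetical stand-ins, not syn's API. The sketch rewrites only paths carrying a `static_dispatch` marker, keeps all leading segments, and leaves unmarked occurrences of the same name (such as a shadowing closure) alone.

```rust
// Toy stand-in for syn's expression tree (hypothetical; syn's real types differ).
#[derive(Debug, PartialEq)]
enum Expr {
    Path { attrs: Vec<String>, segments: Vec<String> },
    Call { func: Box<Expr>, args: Vec<Expr> },
    Lit(i64),
}

/// Rewrites every path marked `static_dispatch` to `<original path>::<clone>`,
/// keeping all leading segments and dropping the marker attribute.
fn rewrite(expr: &mut Expr, clone: &str) {
    match expr {
        Expr::Path { attrs, segments } => {
            if let Some(pos) = attrs.iter().position(|a| a == "static_dispatch") {
                attrs.remove(pos);
                segments.push(clone.to_string());
            }
        }
        Expr::Call { func, args } => {
            rewrite(func, clone);
            for arg in args.iter_mut() {
                rewrite(arg, clone);
            }
        }
        Expr::Lit(_) => {}
    }
}

/// Convenience constructor for path expressions.
fn path(attrs: &[&str], segments: &[&str]) -> Expr {
    Expr::Path {
        attrs: attrs.iter().map(|s| s.to_string()).collect(),
        segments: segments.iter().map(|s| s.to_string()).collect(),
    }
}
```

For `#[static_dispatch] crate::foo::bar(foo(1))` with clone `x86_64_avx`, only the outer path becomes `crate::foo::bar::x86_64_avx`; the unmarked inner `foo` is untouched.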

I've done a bit of work in the feature/static-dispatch branch in preparation for this, but haven't gotten to the actual static dispatch component yet. I've split the target clones out into a module, and I've also canonicalized the names by sorting and deduping the features (x86+sse+avx, x86+avx+sse, x86+sse+avx+sse now all produce a function named features_avx_sse). Additionally, I've reduced the number of nested functions to two by producing something like this:

#[target_feature(enable = "avx")]
unsafe fn avx(/* args */) {
    #[inline(always)]
    fn foo(/* args */) { /* body */ }
    foo(/* args */)
}

fn default(/* args */) {
    #[inline(always)]
    fn foo(/* args */) { /* body */ }
    foo(/* args */)
}

I believe this should have the same safety guarantee.

If you'd like, we can merge my changes to master before you look at the actual static dispatch component.
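The name canonicalization described in this comment can be sketched as follows. It assumes a spec of the form `arch+feature+feature...` with a single architecture component; `clone_name` is a hypothetical helper and the branch's actual code may differ.

```rust
/// Hypothetical sketch: sort and deduplicate the features so equivalent
/// clone specs map to the same function name.
fn clone_name(spec: &str) -> String {
    // `spec` looks like "x86+sse+avx"; the first component is the architecture.
    let mut parts = spec.split('+');
    let _arch = parts.next();
    let mut features: Vec<&str> = parts.collect();
    features.sort_unstable();
    features.dedup();
    let mut name = String::from("features");
    for f in features {
        name.push('_');
        name.push_str(f);
    }
    name
}
```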

@TethysSvensson
Contributor

Doing it that way will not allow optimizations to work for recursive functions, because foo will never be inlined into something that has the target_feature attribute.

@TethysSvensson
Copy link
Contributor

And I believe that Syn has some visitor functionality to go through the AST recursively for you and only look for specific things.

I see your point about global replacement. I am not sure I can think of a better solution than the one you propose.

@calebzulawski
Owner Author

I'm not sure I understand the inlining issue, in this example it seems to inline just fine: https://rust.godbolt.org/z/DAk5P5. That said, I'm not terribly concerned about it and I'll revert that change if that third function is necessary.

If Syn has visitor functionality that would be great. One thought about call site replacement is that something like this doesn't work:

let f = #[static_dispatch] foo;

I figure if the user absolutely needs that, however, they can use a closure.

@calebzulawski
Owner Author

Actually, now that I think about it, that example probably works fine as long as you don't specify foo as safe.

@TethysSvensson
Contributor

I am pretty sure I had an issue with inlining previously using something similar to what you propose. However I cannot replicate the problem right now, so it's probably fine.

@TethysSvensson
Contributor

Why wouldn't that example work?

@TethysSvensson
Contributor

{
    #[allow(unused_unsafe)]
    unsafe { foo(/* args */) }
}

I think this is a bad idea, as it allows the user to have unsafe code in the arguments without causing compilation issues, e.g. foo(transmute::<usize, &[u8]>(1))

@calebzulawski
Copy link
Owner Author

Very good point. Returning the function pointer is probably a better way to do that.
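A sketch of the function-pointer approach, assuming a hypothetical `foo_clone` lookup: the `unsafe` block covers only obtaining the clone, and the returned pointer is an ordinary safe `fn`, so the call's arguments are checked as safe code.

```rust
type FooFn = fn(i32, i32) -> i32;

/// # Safety
/// A real version would require the caller to have verified the clone's CPU
/// features (hypothetical contract; this stand-in needs none).
unsafe fn foo_clone() -> FooFn {
    fn foo(x: i32, y: i32) -> i32 {
        x + y
    }
    foo
}

fn bar(x: i32) -> i32 {
    // `x * 2` and `x` are evaluated outside the unsafe block, so an unsafe
    // expression here (e.g. a transmute) would still fail to compile.
    (unsafe { foo_clone() })(x * 2, x)
}
```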

@TethysSvensson
Copy link
Contributor

I think I found the counter example now. This used to optimize correctly, but no longer does:

#[multiversion::target_clones("x86_64+avx")]
pub fn square(i: i32, x: &mut [f32]) {
    if i <= 1 {
        for v in x {
            *v *= *v;
        }
    } else {
        square(i - 1, &mut x[1..]);
        square(i - 2, &mut x[2..]);
    }
}

@TethysSvensson
Contributor

In the current master, the innermost function is no longer inlined, despite the #[inline(always)] attribute.

@calebzulawski
Copy link
Owner Author

Looks like you're right, it appears to work in some trivial examples but that's it. I think this is caused by rust-lang/rust#53117. I've added the recursion-helper back.
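A hand expansion of the recursion-helper pattern for `square`, roughly following the three-function layout proposed earlier in the thread. The names `square_avx`, `square_default`, and `body` are hypothetical. Recursive calls from the safe inner copy go back to the AVX clone instead of through the runtime dispatcher.

```rust
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
unsafe fn square_avx(i: i32, x: &mut [f32]) {
    #[inline(always)]
    fn body(i: i32, x: &mut [f32]) {
        if i <= 1 {
            for v in x.iter_mut() {
                *v *= *v;
            }
        } else {
            // Sound only because `body` is reachable solely from `square_avx`,
            // i.e. after AVX support has been established.
            unsafe { square_avx(i - 1, &mut x[1..]) };
            unsafe { square_avx(i - 2, &mut x[2..]) };
        }
    }
    body(i, x)
}

fn square_default(i: i32, x: &mut [f32]) {
    if i <= 1 {
        for v in x.iter_mut() {
            *v *= *v;
        }
    } else {
        square_default(i - 1, &mut x[1..]);
        square_default(i - 2, &mut x[2..]);
    }
}

pub fn square(i: i32, x: &mut [f32]) {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("avx") {
            // AVX support was just checked at runtime.
            return unsafe { square_avx(i, x) };
        }
    }
    square_default(i, x)
}
```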

@TethysSvensson
Copy link
Contributor

I have made a branch implementing this. Feel free to use directly or take partial inspiration from it. It does what we agreed upon.

The main downside is that it currently breaks the multiversion!() macro, since the macro does not create a module with the corresponding functions. It probably should though if we want static dispatch to also work for these functions.

@calebzulawski
Copy link
Owner Author

I've done some testing and unfortunately I'm not sure this method is going to work. The compiler has a really hard time keeping track of the CPU features and inlining.

If the returned function is transmuted from the unsafe fn pointer as you currently have in your branch, the function isn't inlined when static-dispatched. If the function is wrapped one more time (or the recursion helper is returned, these seem to have the same effect), the function is now properly inlined when static-dispatched, but somehow the CPU features are lost in the version used by the dynamic dispatcher.

I think a better solution would be to drop the indirection via function pointers and go back to exposing the functions directly. Since we can't transmute at the call site (for static dispatch), I think we might be able to use an unsafe block, with extra care to make sure the arguments are evaluated outside the unsafe block. Using syn we can count the number of arguments and make something like:

{
    let __arg_0 = /* some expr */;
    let __arg_1 = /* some other expr */;
    {
        #[allow(unused_unsafe)]
        unsafe { foo(__arg_0, __arg_1) }
    }
}
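A concrete, runnable version of that expansion, assuming a hypothetical `foo_avx_clone` that is `unsafe` only because a real clone would carry `#[target_feature]`. The hoisted arguments are type-checked as safe code; only the call itself sits inside the `unsafe` block.

```rust
/// Stand-in clone: `unsafe` only because a real one would carry `#[target_feature]`.
unsafe fn foo_avx_clone(a: i32, b: i32) -> i32 {
    a * b
}

fn bar(x: i32) -> i32 {
    {
        // Arguments are evaluated first, outside any unsafe block.
        let __arg_0 = x + 1;
        let __arg_1 = x - 1;
        // Only the call itself is unsafe.
        unsafe { foo_avx_clone(__arg_0, __arg_1) }
    }
}
```

This does not help when the original function is itself an `unsafe fn`, because the same `unsafe` block would then silently permit the call from safe code.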

@calebzulawski
Copy link
Owner Author

I'm now remembering that this solution doesn't work properly if the original function is unsafe...

@TethysSvensson
Copy link
Contributor

Maybe we can hack something up using rust-lang/rust#64035?

E.g. having a macro that generates an inline(always) wrapper with the correct signature in some other scope?

@calebzulawski
Copy link
Owner Author

calebzulawski commented Oct 23, 2019

I think something like that would definitely work. I wonder if we could even just make a single macro that produces all of the various versions (or std::compile_error if the specific feature combination doesn't exist), to prevent having to expose a whole bunch of functions inside a module. A #[static_dispatch] attribute would just be some syntactic sugar on top of that.

I actually really like the idea of being able to report with std::compile_error if you try to static-dispatch into a function that doesn't have the correct feature combination.

Did you have a specific idea of how it would work?
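A minimal sketch of the single-macro idea, assuming clones are addressed by a canonical, sorted feature list and using plain functions as stand-ins for the clones (`foo_clone!` and the `foo_*` names are hypothetical). Unknown combinations expand to `compile_error!`.

```rust
// Stand-ins for the clones of `foo`; real clones would be `#[target_feature]` functions.
fn foo_default() -> &'static str {
    "default"
}
fn foo_sse() -> &'static str {
    "sse"
}
fn foo_avx_sse() -> &'static str {
    "avx+sse"
}

// Maps a canonical (sorted) feature list to the matching clone at compile time.
macro_rules! foo_clone {
    () => { foo_default };
    (sse) => { foo_sse };
    (avx sse) => { foo_avx_sse };
    ($($feature:ident)*) => {
        compile_error!(concat!(
            "`foo` has no clone for features: ",
            stringify!($($feature)*)
        ))
    };
}
```

Because a macro arm is only expanded when it matches, the `compile_error!` arm fires only for a feature combination that has no clone, for example `foo_clone!(avx)`.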

@calebzulawski
Owner Author

calebzulawski commented Nov 2, 2019

For the record, I think I discovered why it didn't work. The compiler needs to put "breaks" before code that may require feature detection, in order to prevent accidentally performing speculative execution on an unsupported function before feature detection is complete. I believe this is why inlining was lost when transmuting a function pointer (I suppose speculative execution stops at a call/jmp), but not when calling the function in a function that already has the features enabled.

@calebzulawski
Copy link
Owner Author

Added in #12.
