diff --git a/.gitignore b/.gitignore index 416741a5e704..0da5320d840d 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,4 @@ tests/mxnet_unit_tests # generated wrappers for ccache cc -cxx +cxx \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile index 10fdf1d6cfab..48d68ef11460 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -702,6 +702,18 @@ try { } } }, + 'Clojure: CPU': { + node('mxnetlinux-cpu') { + ws('workspace/ut-clojure-cpu') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + unpack_lib('cpu', mx_dist_lib) + docker_run('ubuntu_cpu', 'unittest_ubuntu_cpu_clojure', false) + publish_test_coverage() + } + } + } + }, 'Perl: CPU': { node('mxnetlinux-cpu') { ws('workspace/ut-perl-cpu') { diff --git a/ci/docker/Dockerfile.build.ubuntu_cpu b/ci/docker/Dockerfile.build.ubuntu_cpu index 57cf1e93e542..598b9af86ad8 100755 --- a/ci/docker/Dockerfile.build.ubuntu_cpu +++ b/ci/docker/Dockerfile.build.ubuntu_cpu @@ -30,6 +30,8 @@ COPY install/ubuntu_python.sh /work/ RUN /work/ubuntu_python.sh COPY install/ubuntu_scala.sh /work/ RUN /work/ubuntu_scala.sh +COPY install/ubuntu_clojure.sh /work/ +RUN /work/ubuntu_clojure.sh COPY install/ubuntu_r.sh /work/ RUN /work/ubuntu_r.sh COPY install/ubuntu_perl.sh /work/ diff --git a/ci/docker/install/ubuntu_clojure.sh b/ci/docker/install/ubuntu_clojure.sh new file mode 100755 index 000000000000..c1a6b7f06d9f --- /dev/null +++ b/ci/docker/install/ubuntu_clojure.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# build and install are separated so changes to build don't invalidate +# the whole docker cache for the image + +set -ex +# install libraries for mxnet's clojure package on ubuntu +echo 'Installing Clojure...' + +wget https://raw.githubusercontent.com/technomancy/leiningen/stable/bin/lein +chmod 775 lein +sudo cp lein /usr/local/bin diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index e49639903f92..083ddf57fddc 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -652,6 +652,13 @@ unittest_ubuntu_gpu_scala() { make scalatest USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 SCALA_TEST_ON_GPU=1 USE_DIST_KVSTORE=1 } +unittest_ubuntu_cpu_clojure() { + set -ex + make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_DIST_KVSTORE=1 + make scalainstall USE_OPENCV=1 USE_BLAS=openblas USE_DIST_KVSTORE=1 + ./contrib/clojure-package/ci-test.sh +} + unittest_ubuntu_cpugpu_perl() { set -ex ./perl-package/test.sh diff --git a/contrib/clojure-package/.gitignore b/contrib/clojure-package/.gitignore new file mode 100644 index 000000000000..f634b900921a --- /dev/null +++ b/contrib/clojure-package/.gitignore @@ -0,0 +1,43 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ +data/* +model/* +*~ +*.params +*.states +*.json +examples/module/data/* +examples/module/target/* +examples/rnn/data/char_lstm.zip +examples/rnn/data/obama.txt +examples/pre-trained-models/caltech-256/caltech-256-60-train.rec +examples/pre-trained-models/caltech-256/caltech-256-60-val.rec +examples/pre-trained-models/model/synset.txt +examples/pre-trained-models/test-image.jpg +examples/imclassification/data/* +examples/gan/data/* +examples/gan/results/* +examples/cnn-text-classification/data/glove/* +examples/cnn-text-classification/data/mr-data/* +examples/multi-label/data/mnist.zip +examples/multi-label/data/t10k-images-idx3-ubyte +examples/multi-label/data/t10k-labels-idx1-ubyte +examples/multi-label/data/train-images-idx3-ubyte +examples/multi-label/data/train-labels-idx1-ubyte +examples/visualization/test-vis/* +examples/visualization/test-vis.pdf +.DS_Store +src/.DS_Store +src/org/.DS_Store +test/test-ndarray.clj +test/test-symbol.clj + diff --git a/contrib/clojure-package/LICENSE b/contrib/clojure-package/LICENSE new file mode 100644 index 000000000000..8f71f43fee3f --- /dev/null +++ b/contrib/clojure-package/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/contrib/clojure-package/README.md b/contrib/clojure-package/README.md new file mode 100644 index 000000000000..fb31df4c35b0 --- /dev/null +++ b/contrib/clojure-package/README.md @@ -0,0 +1,230 @@ +# Clojure MXNet + +A clojure package to the MXNet Deep Learning library + +## Introduction + +MXNet is a first class, modern deep learning library. It supports multiple languages on a first class basis and is incubating as an Apache project. + +The motivation for creating a Clojure package is to be able to open the deep learning library to the Clojure ecosystem and build bridges for future development and innovation for the community. It provides all the needed tools including low level and high level apis, dynamic graphs, and things like GAN and natural language support. + +For high leverage, the Clojure package has been built on the existing Scala package using interop. This has allowed rapid development and close parity with the Scala functionality. This also leaves the door open to directly developing code against the jni-bindings with Clojure in the future in an incremental fashion, using the test suites as a refactoring guide. + +## Current State and Plans + +The Clojure package is nearing the end of its first development milestone which is to achieve a close parity with the Scala package. + +Help is needed testing and generally making the package better. A list of the pacakge status and contribution needs can be found here [Clojure Package Contribution Needs](https://cwiki.apache.org/confluence/display/MXNET/Clojure+Package+Contribution+Needs). Please get involved :) + +Testing instructions can be found in the testing.md. + +## Getting Started + +The following systems are supported: + +- OSX cpu +- Linux cpu +- Linux gpu + +There are two ways of getting going. The first way is the easiest and that is to use the pre-built jars from Maven. The second way is to build from source. In both cases, you will need to load the prereqs and dependencies, (like opencv). + + + +### Prerequisites + + +Follow the instructions from https://mxnet.incubator.apache.org/install/osx_setup.html or https://mxnet.incubator.apache.org/install/ubuntu_setup.html +about _Prepare Environment for GPU Installation_ +and _Install MXNet dependencies_ + + +#### Cloning the repo and running from source + +To use the prebuilt jars (easiest), you will need to replace the native version of the line in the project dependencies with your configuration. + +`[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]` +or +`[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]` +or +`[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]` + +If you are using the prebuilt jars they may have a slightly different dependencies then building from source: + +*For OSX you will need:* + +`brew install opencv` + +*For Ubuntu Linux you will need:* + +``` +sudo add-apt-repository ppa:timsc/opencv-3.4 +sudo apt-get update +sudo apt install libopencv-imgcodecs3.4 +``` + +*For Arch Linux you will need:* + +_CPU_ + +``` +yaourt -S openblas-lapack +yaourt -S libcurl-compat +export LD_PRELOAD=libcurl.so.3 +``` +_GPU_ + +``` +wget https://archive.archlinux.org/packages/c/cuda/cuda-9.0.176-4-x86_64.pkg.tar.xz +sudo pacman -U cuda-9.0.176-4-x86_64.pkg.tar.xz +``` + +If you want to see the exact versions and flags that the jars were built with, look here: +[Scala Release Process](https://cwiki.apache.org/confluence/display/MXNET/MXNet-Scala+Release+Process) + + +Check your installation with `lein test`. If that works alright then, you can try some code! + +```clojure + +(ns tutorial.ndarray + (:require [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.context :as context])) + +;;Create NDArray +(def a (ndarray/zeros [100 50])) ;;all zero arrray of dimension 100 x 50 +(def b (ndarray/ones [256 32 128 1])) ;; all one array of dimension +(def c (ndarray/array [1 2 3 4 5 6] [2 3])) ;; array with contents of a shape 2 x 3 + +;;; There are also ways to convert to a vec or get the shape as an object or vec +(ndarray/->vec c) ;=> [1.0 2.0 3.0 4.0 5.0 6.0] +``` + +See the examples/tutorial section for more. + + +The jars from maven with the needed MXNet native binaries in it. On startup, the native libraries are extracted from the jar and copied into a temporary location on your path. On termination, they are deleted. + + +### Build from MXNET Source + +Checkout the latest sha from the main package + +`git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet` +`cd ~/mxnet` + +If you need to checkout a particular release you can do it with: + +`git checkout tags/1.2.1 -b release-1.2.1` + +`git submodule update --init --recursive` + +Sometimes it useful to use this script to clean hard +https://gist.github.com/nicktoumpelis/11214362 + + +Go here to do the base package installation https://mxnet.incubator.apache.org/install/index.html + + Run `make scalapkg` then `make scalainstall` + +then replace the correct jar for your architecture in the project.clj, example `[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]` + +#### Test your installation + +To test your installation, you should run `lein test`. This will run the test suite (CPU) for the clojure package. + + +#### Generation of NDArray and Symbol apis + +The bulk of the ndarray and symbol apis are generated via java reflection into the Scala classes. To generate, use the `dev/generator.clj` file. These generated files are checked in as source, so the only time you would need to run them is if you are updated the clojure package with an updated scala jar and want to regenerate the code. + +To do this run the leiningen task +`lein run generate-code` + +Or load in the repl and use the functions: + +`(generate-ndarray-file)` +and +`(generate-symbol-file)` + + +These will generate the files under `src/org.apache.clojure-mxnet/gen/` that are loaded by the `src/org.apache.clojure-mxnet/ndarray.clj` and `src/org.apache.clojure-mxnet/symbol.clj` files. + + +## Examples +There are quite a few examples in the examples directory. To use. + +`lein install` in the main project +`cd` in the the example project of interest + +There are README is every directory outlining instructions. + +A good place to get started is the module example. +Do `lein run` for the cpu version or `lein run :gpu` for gpu. + +## Generating documentation + +To generate api docs, run `lein codox`. The html docs will be generated in the target/docs directory. + +_Note: There is an error thrown in the generated code due to some loading issues, but the docs are all still there._ + +## Code Coverage + +To run the Code Coverage tool. Run `lein cloverage`. + +## FAQ + + +**Why build on the Scala package?** + +The motivation section addresses this, but the main reason is high leverage is using the great work that the Scala package has already done. + +**How can I tell if the gpu is being used?** + +CUDA is finding a best algorithm... As long as a Context.gpu() passed in the code as a context, GPU should be used. + +This command can be very handy too + +`nvidia-smi --query-gpu=timestamp,name,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 5 +timestamp, name, utilization.gpu [%], utilization.memory [%], memory.total [MiB], memory.free [MiB], memory.used [MiB]` + +**Supported APIs** +There are 3 high level apis supported in MxNet: (Model/FeedForward), Module, and Gluon. The Module api is supported in the Clojure package because of the existing support for it in the Scala package. The Module api is very similar to the Gluon api and examples of the usage can be found in the examples directory. The Model/FeedForward Api is deprected. + +Gluon support will come later and may or may not be built on the Scala gluon api (when it lands there) + +## Architecture & Design + +See the Confluence page: https://cwiki.apache.org/confluence/display/MXNET/MXNet+Clojure + +## Building and Deploying Jars +The process to build and deploy the jars currently is a manual process using the `lein` build tool and `Clojars`, the Clojure dependency hosting platform. + +There is one jar for every system supported. + +- Comment out the line in the `project.clj` for the system that you are targeting, (example OSX cpu you would uncomment out ` [org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.0"]` but leave the linux deps commented) +- Change the `defproject org.apache.mxnet.contrib.clojure/clojure-mxnet "0.1.1-SNAPSHOT"` in the project to reference the correct version number and jar description. For example changing the line to be `org.apache.mxnet.contrib.clojure/mxnet-osx-cpu "0.1.2"` would create a jar with the group id of `org.apache.mxnet.contrib.clojure` and the artifact name of `mxnet-osx-cpu` and the version of `0.1.2` +- Run `lein clean` +- Run `lein jar` to create the jar +- Check that the jar looks alright in the `/target` directory. + +To deploy the jar to Clojars, you do `lein deploy clojars` and it will prompt you for your username and password. + +_Note: Integration with deployment to Nexus can be enabled too for the future [https://help.sonatype.com/repomanager2/maven-and-other-build-tools/leiningen](https://help.sonatype.com/repomanager2/maven-and-other-build-tools/leiningen)_ + +You would repeat this process for all the build system types. + + +## Special Thanks +Special thanks to people that provided testing and feedback to make this possible + +- Chris Hodapp +- IƱaki Arenaza & Magnet Coop +- r0man +- Ben Kamphaus +- Sivaram Konanki +- Rustam Gilaztdinov +- Kamil Hryniewicz +- Christian Weilbach +- Burin Choomnuan +- Avram Aelony +- Jim Dunn diff --git a/contrib/clojure-package/ci-test.sh b/contrib/clojure-package/ci-test.sh new file mode 100755 index 000000000000..eda3919f5ce0 --- /dev/null +++ b/contrib/clojure-package/ci-test.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +MXNET_HOME=${PWD} +cd ${MXNET_HOME}/contrib/clojure-package +lein test diff --git a/contrib/clojure-package/doc/getting-started/Archlinux.md b/contrib/clojure-package/doc/getting-started/Archlinux.md new file mode 100644 index 000000000000..daa8aa8bcd82 --- /dev/null +++ b/contrib/clojure-package/doc/getting-started/Archlinux.md @@ -0,0 +1,5 @@ +## Getting Started on Arch Linux + +There are a few steps to get up and running with the dependencies on Ubuntu. + +Please see this issue for a handy guide: [Running on Arch Linux](https://github.com/gigasquid/clojure-mxnet/issues/1) diff --git a/contrib/clojure-package/doc/getting-started/Ubuntu.md b/contrib/clojure-package/doc/getting-started/Ubuntu.md new file mode 100644 index 000000000000..b9117d2505d6 --- /dev/null +++ b/contrib/clojure-package/doc/getting-started/Ubuntu.md @@ -0,0 +1,8 @@ +## Getting Started on Ubuntu Linux + +There are a few steps to get up and running with the dependencies on Ubuntu. + +See this for reference: + +https://github.com/apache/incubator-mxnet/issues/11303 + diff --git a/contrib/clojure-package/doc/intro.md b/contrib/clojure-package/doc/intro.md new file mode 100644 index 000000000000..a3c155f7bf68 --- /dev/null +++ b/contrib/clojure-package/doc/intro.md @@ -0,0 +1,12 @@ +# Clojure MXNet + +A clojure package to the MXNet Deep Learning library + +## Introduction + +MXNet is a first class, modern deep learning library that AWS has officially picked as its chosen library. It supports multiple languages on a first class basis and is incubating as an Apache project. + +The motivation for creating a Clojure package is to be able to open the deep learning library to the Clojure ecosystem and build bridges for future development and innovation for the community. It provides all the needed tools including low level and high level apis, dynamic graphs, and things like GAN and natural language support. + +For high leverage, the Clojure package has been built on the existing Scala package using interop. This has allowed rapid development and close parity with the Scala functionality. This also leaves the door open to directly developing code against the jni-bindings with Clojure in the future in an incremental fashion, using the test suites as a refactoring guide. + diff --git a/contrib/clojure-package/examples/cnn-text-classification/.gitignore b/contrib/clojure-package/examples/cnn-text-classification/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/cnn-text-classification/README.md b/contrib/clojure-package/examples/cnn-text-classification/README.md new file mode 100644 index 000000000000..86a8abb06e7a --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/README.md @@ -0,0 +1,33 @@ +# cnn-text-classification + +An example of text classification using CNN + +To use you must download the MR polarity dataset and put it in the path specified in the mr-dataset-path +The dataset can be obtained here: [https://github.com/yoonkim/CNN_sentence](https://github.com/yoonkim/CNN_sentence). The two files `rt-polarity.neg` +and `rt-polarity.pos` must be put in a directory. For example, `data/mr-data/rt-polarity.neg`. + +You also must download the glove word embeddings. The suggested one to use is the smaller 50 dimension one +`glove.6B.50d.txt` which is contained in the download file here [https://nlp.stanford.edu/projects/glove/](https://nlp.stanford.edu/projects/glove/) + +## Usage + +You can run through the repl with +`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000})` + +or +`JVM_OPTS="Xmx1g" lein run` (cpu) + +You can control the devices you run on by doing: + +`lein run :cpu 2` - This will run on 2 cpu devices +`lein run :gpu 1` - This will run on 1 gpu device +`lein run :gpu 2` - This will run on 2 gpu devices + + +The max-examples only loads 1000 each of the dataset to keep the time and memory down. To run all the examples, +change the main to be (train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10) + +and then run + +- `lein uberjar` +- `java -Xms1024m -Xmx2048m -jar target/cnn-text-classification-0.1.0-SNAPSHOT-standalone.jar` diff --git a/contrib/clojure-package/examples/cnn-text-classification/get_data.sh b/contrib/clojure-package/examples/cnn-text-classification/get_data.sh new file mode 100755 index 000000000000..7bbd9ce72142 --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/get_data.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p data/mr-data +cd data/mr-data +wget https://raw.githubusercontent.com/yoonkim/CNN_sentence/master/rt-polarity.neg +wget https://raw.githubusercontent.com/yoonkim/CNN_sentence/master/rt-polarity.pos +cd ../.. +mkdir -p data/glove +cd data/glove +wget http://nlp.stanford.edu/data/glove.6B.zip +unzip *.zip +cd ../.. diff --git a/contrib/clojure-package/examples/cnn-text-classification/project.clj b/contrib/clojure-package/examples/cnn-text-classification/project.clj new file mode 100644 index 000000000000..f3eb21ab547a --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/project.clj @@ -0,0 +1,23 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject cnn-text-classification "0.1.0-SNAPSHOT" + :description "CNN text classification with MXNet" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]] + :main cnn-text-classification.classifier + :pedantic? :skip) diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj new file mode 100644 index 000000000000..756328caf7a4 --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj @@ -0,0 +1,114 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns cnn-text-classification.classifier + (:require [cnn-text-classification.data-helper :as data-helper] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.context :as context]) + (:gen-class)) + +(def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path +(def glove-file-path "data/glove/glove.6B.50d.txt") +(def num-filter 100) +(def num-label 2) +(def dropout 0.5) + +(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}] + (println "Shuffling the data and splitting into training and test sets") + (println {:sentence-count sentence-count + :sentence-size sentence-size + :embedding-size embedding-size}) + (let [shuffled (shuffle (map #(vector %1 %2) data label)) + train-num (- (count shuffled) test-num) + training (into [] (take train-num shuffled)) + test (into [] (drop train-num shuffled))] + {:training {:data (ndarray/array (into [] (flatten (mapv first training))) + [train-num 1 sentence-size embedding-size]) ;; has to be channel x y + :label (ndarray/array (into [] (flatten (mapv last training))) + [train-num])} + :test {:data (ndarray/array (into [] (flatten (mapv first test))) + [test-num 1 sentence-size embedding-size]) ;; has to be channel x y + :label (ndarray/array (into [] (flatten (mapv last test))) + [test-num])}})) + +(defn make-filter-layers [{:keys [input-x num-embed sentence-size] :as config} + filter-size] + (as-> (sym/convolution {:data input-x + :kernel [filter-size num-embed] + :num-filter num-filter}) data + (sym/activation {:data data :act-type "relu"}) + (sym/pooling {:data data + :pool-type "max" + :kernel [(inc (- sentence-size filter-size)) 1] + :stride [1 1]}))) + +;;; convnet with multiple filter sizes +;; from Convolutional Neural Networks for Sentence Classification by Yoon Kim +(defn get-multi-filter-convnet [num-embed sentence-size batch-size] + (let [filter-list [3 4 5] + input-x (sym/variable "data") + polled-outputs (mapv #(make-filter-layers {:input-x input-x :num-embed num-embed :sentence-size sentence-size} %) filter-list) + total-filters (* num-filter (count filter-list)) + concat (sym/concat "concat" nil polled-outputs {:dim 1}) + hpool (sym/reshape "hpool" {:data concat :target-shape [batch-size total-filters]}) + hdrop (if (pos? dropout) (sym/dropout "hdrop" {:data hpool :p dropout}) hpool) + fc (sym/fully-connected "fc1" {:data hdrop :num-hidden num-label})] + (sym/softmax-output "softmax" {:data fc}))) + +(defn train-convnet [{:keys [devs embedding-size batch-size test-size num-epoch max-examples]}] + (let [glove (data-helper/load-glove glove-file-path) ;; you can also use word2vec + ms-dataset (data-helper/load-ms-with-embeddings mr-dataset-path embedding-size glove max-examples) + sentence-size (:sentence-size ms-dataset) + shuffled (shuffle-data test-size ms-dataset) + train-data (mx-io/ndarray-iter [(get-in shuffled [:training :data])] + {:label [(get-in shuffled [:training :label])] + :label-name "softmax_label" + :data-batch-size batch-size + :last-batch-handle "pad"}) + test-data (mx-io/ndarray-iter [(get-in shuffled [:test :data])] + {:label[(get-in shuffled [:test :label])] + :label-name "softmax_label" + :data-batch-size batch-size + :last-batch-handle "pad"})] + (let [mod (m/module (get-multi-filter-convnet embedding-size sentence-size batch-size) {:contexts devs})] + (println "Getting ready to train for " num-epoch " epochs") + (println "===========") + (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch + :fit-params (m/fit-params {:optimizer (optimizer/adam)})})))) + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + ;;; omit max-examples if you want to run all the examples in the movie review dataset + ;; to limit mem consumption set to something like 1000 and adjust test size to 100 + (println "Running with context devices of" devs) + (train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}) + ;; runs all the examples + #_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10}))) + + +(comment + (train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}) + ) + diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj new file mode 100644 index 000000000000..e7a706fb03e4 --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj @@ -0,0 +1,152 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns cnn-text-classification.data-helper + (:require [clojure.java.io :as io] + [clojure.string :as string] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.random :as random]) + (:import (java.io DataInputStream)) + (:gen-class)) + +(def w2v-file-path "../../data/GoogleNews-vectors-negative300.bin") ;; the word2vec file path +(def max-vectors 100) ;; If you are using word2vec embeddings and you want to only load part of them + +(defn r-string [dis] + (let [max-size 50 + bs (byte-array max-size) + sb (new StringBuilder)] + (loop [b (.readByte dis) + i 0] + (if (and (not= 32 b) (not= 10 b)) + (do (aset bs i b) + (if (= 49 i) + (do (.append sb (new String bs)) + (recur (.readByte dis) 0)) + (recur (.readByte dis) (inc i)))) + (.append sb (new String bs 0 i)))) + (.toString sb))) + + +(defn get-float [b] + (-> 0 + (bit-or (bit-shift-left (bit-and (aget b 0) 0xff) 0)) + (bit-or (bit-shift-left (bit-and (aget b 1) 0xff) 8)) + (bit-or (bit-shift-left (bit-and (aget b 2) 0xff) 16)) + (bit-or (bit-shift-left (bit-and (aget b 3) 0xff) 24)))) + +(defn read-float [is] + (let [bs (byte-array 4)] + (do (.read is bs) + (get-float bs)))) + +(defn load-google-model [path] + (println "Loading the word2vec model from binary ...") + (with-open [bis (io/input-stream path) + dis (new DataInputStream bis)] + (let [word-size (Integer/parseInt (r-string dis)) + dim (Integer/parseInt (r-string dis)) + _ (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors) + word2vec (reduce (fn [r _] + (assoc r (r-string dis) + (mapv (fn [_] (read-float dis)) (range dim)))) + {} + (range max-vectors))] + (println "Finished") + {:num-embed dim :word2vec word2vec}))) + + +(defn clean-str [s] + (-> s + (string/replace #"^A-Za-z0-9(),!?'`]" " ") + (string/replace #"'s" " 's") + (string/replace #"'ve" " 've") + (string/replace #"n't" " n't") + (string/replace #"'re" " 're") + (string/replace #"'d" " 'd") + (string/replace #"'ll" " 'll") + (string/replace #"," " , ") + (string/replace #"!" " ! ") + (string/replace #"\(" " ( ") + (string/replace #"\)" " ) ") + (string/replace #"\?" " ? ") + (string/replace #" {2,}" " ") + (string/trim))) + + + ;; Loads MR polarity data from files, splits the data into words and generates labels. + ;; Returns split sentences and labels. +(defn load-mr-data-and-labels [path max-examples] + (println "Loading all the movie reviews from " path) + (let [positive-examples (mapv #(string/trim %) (-> (slurp (str path "/rt-polarity.pos")) + (string/split #"\n"))) + negative-examples (mapv #(string/trim %) (-> (slurp (str path "/rt-polarity.neg")) + (string/split #"\n"))) + positive-examples (into [] (if max-examples (take max-examples positive-examples) positive-examples)) + negative-examples (into [] (if max-examples (take max-examples negative-examples) negative-examples)) + ;; split by words + x-text (->> (into positive-examples negative-examples) + (mapv clean-str) + (mapv #(string/split % #" "))) + + ;; generate labels + positive-labels (mapv (constantly 1) positive-examples) + negative-labels (mapv (constantly 0) negative-examples)] + {:sentences x-text :labels (into positive-labels negative-labels)})) + +;; Pads all sentences to the same length. The length is defined by the longest sentence. +;; Returns padded sentences. +(defn pad-sentences [sentences] + (let [padding-word "" + sequence-len (apply max (mapv count sentences))] + (mapv (fn [s] (let [diff (- sequence-len (count s))] + (if (pos? diff) + (into s (repeat diff padding-word)) + s))) + sentences))) + + + ;; Map sentences and labels to vectors based on a pretrained embeddings +(defn build-input-data-with-embeddings [sentences embedding-size embeddings] + (mapv (fn [sent] + (mapv (fn [word] (or (get embeddings word) + (ndarray/->vec (random/uniform -0.25 0.25 [embedding-size])))) + sent)) + sentences)) + +(defn load-ms-with-embeddings [path embedding-size embeddings max-examples] + (println "Translating the movie review words into the embeddings") + (let [{:keys [sentences labels]} (load-mr-data-and-labels path max-examples) + sentences-padded (pad-sentences sentences) + data (build-input-data-with-embeddings sentences-padded embedding-size embeddings)] + {:data data + :label labels + :sentence-count (count data) + :sentence-size (count (first data)) + :embedding-size embedding-size})) + +(defn read-text-embedding-pairs [rdr] + (for [^String line (line-seq rdr) + :let [fields (.split line " ")]] + [(aget fields 0) + (mapv #(Double/parseDouble ^String %) (rest fields))])) + +(defn load-glove [glove-file-path] + (println "Loading the glove pre-trained word embeddings from " glove-file-path) + (into {} (read-text-embedding-pairs (io/reader glove-file-path)))) + diff --git a/contrib/clojure-package/examples/gan/.gitignore b/contrib/clojure-package/examples/gan/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/gan/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/gan/README.md b/contrib/clojure-package/examples/gan/README.md new file mode 100644 index 000000000000..2b46a6cf3e83 --- /dev/null +++ b/contrib/clojure-package/examples/gan/README.md @@ -0,0 +1,16 @@ +# gan + +This is an example of how to do a GAN with the MNIST data + +## Usage + +Do `lein run` and the images generated will be in the `results` directory. The gout* images are the ones generated, the diff* images are the visualization of the input gradient different fed to the generator + +`lein run :gpu` will run on gpu + +If you are running on AWS you will need to setup X11 for graphics +`sudo apt install xauth x11-apps` + +then relogin in `ssh -X -i creds ubuntu@yourinstance` + + diff --git a/contrib/clojure-package/examples/gan/project.clj b/contrib/clojure-package/examples/gan/project.clj new file mode 100644 index 000000000000..1469f9f8f279 --- /dev/null +++ b/contrib/clojure-package/examples/gan/project.clj @@ -0,0 +1,23 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject gan "0.1.0-SNAPSHOT" + :description "GAN MNIST with MXNet" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"] + [nu.pattern/opencv "2.4.9-7"]] + :main gan.gan-mnist) diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj new file mode 100644 index 000000000000..42c116a1bf78 --- /dev/null +++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj @@ -0,0 +1,225 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns gan.gan-mnist + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.initializer :as init] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as opt] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [gan.viz :as viz] + [org.apache.clojure-mxnet.context :as context]) + (:gen-class)) + + +;; based off of https://medium.com/@julsimon/generative-adversarial-networks-on-apache-mxnet-part-1-b6d39e6b5df1 + + +(def data-dir "data/") +(def output-path "results/") +(def batch-size 100) +(def num-epoch 10) + +(io/make-parents (str output-path "gout")) + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "../../scripts/get_mnist_data.sh")) + +(defonce mnist-iter (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :input-shape [1 28 28] + :batch-size batch-size + :shuffle true})) + +(def rand-noise-iter (mx-io/rand-iter [batch-size 100 1 1])) + +(comment + + ;;This is for figuring out the convolution and deconvolution layers to convert the image sizes + + (defn conv-output-size [input-size kernel-size padding stride] + (float (inc (/ (- (+ input-size (* 2 padding)) kernel-size) stride)))) + + ;; Calcing the layer sizes for discriminator + (conv-output-size 28 4 3 2) ;=> 16 + (conv-output-size 16 4 1 2) ;=> 8 + (conv-output-size 8 4 1 2) ;=> 4.0 + (conv-output-size 4 4 0 1) ;=> 1 + + ;; Calcing the layer sizes for generator + (defn deconv-output-size [input-size kernel-size padding stride] + (- + (+ (* stride (- input-size 1)) + kernel-size) + (* 2 padding))) + + (deconv-output-size 1 4 0 1) ;=> 4 + (deconv-output-size 4 4 1 2) ;=> 8 + (deconv-output-size 8 4 1 2) ;=> 16 + (deconv-output-size 16 4 3 2)) ;=> 28 + + +(def ndf 28) ;; image height /width +(def nc 1) ;; number of channels +(def eps (float (+ 1e-5 1e-12))) +(def lr 0.0005) ;; learning rate +(def beta1 0.5) + + +(def label (sym/variable "label")) + +(defn discriminator [] + (as-> (sym/variable "data") data + (sym/convolution "d1" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter ndf :no-bias true}) + (sym/batch-norm "dbn1" {:data data :fix-gamma true :eps eps}) + (sym/leaky-re-lu "dact1" {:data data :act-type "leaky" :slope 0.2}) + + (sym/convolution "d2" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 2 ndf) :no-bias true}) + (sym/batch-norm "dbn2" {:data data :fix-gamma true :eps eps}) + (sym/leaky-re-lu "dact1" {:data data :act_type "leaky" :slope 0.2}) + + (sym/convolution "d3" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 3 ndf) :no-bias true}) + (sym/batch-norm "dbn3" {:data data :fix-gamma true :eps eps}) + (sym/leaky-re-lu "dact3" {:data data :act_type "leaky" :slope 0.2}) + + (sym/convolution "d4" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true}) + (sym/flatten "flt" {:data data}) + + (sym/fully-connected "fc" {:data data :num-hidden 1 :no-bias false}) + (sym/logistic-regression-output "dloss" {:data data :label label}))) + +(defn generator [] + (as-> (sym/variable "rand") data + (sym/deconvolution "g1" {:data data :kernel [4 4] :pad [0 0] :stride [1 1] :num-filter (* 4 ndf) :no-bias true}) + (sym/batch-norm "gbn1" {:data data :fix-gamma true :eps eps}) + (sym/activation "gact1" {:data data :act-type "relu"}) + + (sym/deconvolution "g2" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter (* 2 ndf) :no-bias true}) + (sym/batch-norm "gbn2" {:data data :fix-gamma true :eps eps}) + (sym/activation "gact2" {:data data :act-type "relu"}) + + (sym/deconvolution "g3" {:data data :kernel [4 4] :pad [1 1] :stride [2 2] :num-filter ndf :no-bias true}) + (sym/batch-norm "gbn3" {:data data :fix-gamma true :eps eps}) + (sym/activation "gact3" {:data data :act-type "relu"}) + + (sym/deconvolution "g4" {:data data :kernel [4 4] :pad [3 3] :stride [2 2] :num-filter nc :no-bias true}) + (sym/activation "gact4" {:data data :act-type "tanh"}))) + +(let [data [(ndarray/ones [batch-size 100 1 1])] + label [(ndarray/ones [batch-size 100 1 1])]] + (def my-iter (mx-io/ndarray-iter data))) + + +(defn save-img-gout [i n x] + (do + (viz/im-sav {:title (str "gout-" i "-" n) + :output-path output-path + :x x + :flip false}))) + +(defn save-img-diff [i n x] + (do (viz/im-sav {:title (str "diff-" i "-" n) + :output-path output-path + :x x + :flip false}))) + +(defn save-img-data [i n batch] + (do (viz/im-sav {:title (str "data-" i "-" n) + :output-path output-path + :x (first (mx-io/batch-data batch)) + :flip false}))) + + +(defn calc-diff [i n diff-d] + (let [diff (ndarray/copy diff-d) + arr (ndarray/->vec diff) + mean (/ (apply + arr) (count arr)) + std (let [tmp-a (map #(* (- % mean) (- % mean)) arr)] + (float (Math/sqrt (/ (apply + tmp-a) (count tmp-a)))))] + (let [calc-diff (ndarray/+ (ndarray/div (ndarray/- diff mean) std) 0.5)] + + (save-img-diff i n calc-diff)))) + + + +(defn train [devs] + (let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}) + (m/bind {:data-shapes (mx-io/provide-data mnist-iter) + :label-shapes (mx-io/provide-label mnist-iter) + :inputs-need-grad true}) + (m/init-params {:initializer (init/normal 0.02)}) + (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})})) + mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}) + (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)}) + (m/init-params {:initializer (init/normal 0.02)}) + (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))] + + (println "Training for " num-epoch " epochs...") + (doseq [i (range num-epoch)] + (mx-io/reduce-batches mnist-iter + (fn [n batch] + (let [rbatch (mx-io/next rand-noise-iter) + out-g (-> mod-g + (m/forward rbatch) + (m/outputs)) + ;; update the discriminiator on the fake + grads-f (mapv #(ndarray/copy (first %)) (-> mod-d + (m/forward {:data (first out-g) :label [(ndarray/zeros [batch-size])]}) + (m/backward) + (m/grad-arrays))) + ;; update the discrimintator on the real + grads-r (-> mod-d + (m/forward {:data (mx-io/batch-data batch) :label [(ndarray/ones [batch-size])]}) + (m/backward) + (m/grad-arrays)) + _ (mapv (fn [real fake] (let [r (first real)] + (ndarray/set r (ndarray/+ r fake)))) grads-r grads-f) + _ (m/update mod-d) + ;; update the generator + diff-d (-> mod-d + (m/forward {:data (first out-g) :label [(ndarray/ones [batch-size])]}) + (m/backward) + (m/input-grads)) + _ (-> mod-g + (m/backward (first diff-d)) + (m/update))] + (when (zero? (mod n 100)) + (println "iteration = " i "number = " n) + (save-img-gout i n (ndarray/copy (ffirst out-g))) + (save-img-data i n batch) + (calc-diff i n (ffirst diff-d))) + (inc n))))))) + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (println "Running with context devices of" devs) + (train devs))) + +(comment + (train [(context/cpu)]) + + ) diff --git a/contrib/clojure-package/examples/gan/src/gan/viz.clj b/contrib/clojure-package/examples/gan/src/gan/viz.clj new file mode 100644 index 000000000000..7fab13e38aeb --- /dev/null +++ b/contrib/clojure-package/examples/gan/src/gan/viz.clj @@ -0,0 +1,88 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns gan.viz + (:require [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.io :as mx-io]) + (:import (nu.pattern OpenCV) + (org.opencv.core Core CvType Mat Size) + (org.opencv.imgproc Imgproc) + (org.opencv.highgui Highgui))) + +;;; Viz stuff +(OpenCV/loadShared) + + +(defn clip [x] + (->> x + (mapv #(* 255 %)) + (mapv #(cond + (< % 0) 0 + (> % 255) 255 + :else (int %))) + (mapv #(.byteValue %)))) + +(defn get-img [raw-data channels height width flip] + (let [totals (* height width) + img (if (> channels 1) + ;; rgb image + (let [[ra ga ba] (byte-array (partition totals raw-data)) + rr (new Mat height width (CvType/CV_8U)) + gg (new Mat height width (CvType/CV_8U)) + bb (new Mat height width (CvType/CV_8U)) + result (new Mat)] + (.put rr (int 0) (int 0) ra) + (.put gg (int 0) (int 0) ga) + (.put bb (int 0) (int 0) ba) + (Core/merge (java.util.ArrayList. [bb gg rr]) result) + result) + ;; gray image + (let [result (new Mat height width (CvType/CV_8U)) + _ (.put result (int 0) (int 0) (byte-array raw-data))] + result))] + (do + (if flip + (let [result (new Mat) + _ (Core/flip img result (int 0))] + result) + img)))) + +(defn im-sav [{:keys [title output-path x flip] + :or {flip false} :as g-mod}] + (let [shape (mx-shape/->vec (ndarray/shape x)) + _ (assert (== 4 (count shape))) + [n c h w] shape + totals (* h w) + raw-data (byte-array (clip (ndarray/to-array x))) + row (.intValue(Math/sqrt n)) + col row + line-arrs (into [] (partition (* col c totals) raw-data)) + line-mats (mapv (fn [line] + (let [img-arr (into [] (partition (* c totals) line)) + col-mats (new Mat) + src (mapv (fn [arr] (get-img (into [] arr) c h w flip)) img-arr) + _ (Core/hconcat (java.util.ArrayList. src) col-mats)] + col-mats)) + line-arrs) + result (new Mat) + resized-img (new Mat) + _ (Core/vconcat (java.util.ArrayList. line-mats) result)] + (do + (Imgproc/resize result resized-img (new Size (* (.width result) 1.5) (* (.height result) 1.5))) + (Highgui/imwrite (str output-path title ".jpg") resized-img) + (Thread/sleep 1000)))) diff --git a/contrib/clojure-package/examples/imclassification/.gitignore b/contrib/clojure-package/examples/imclassification/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/imclassification/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/imclassification/README.md b/contrib/clojure-package/examples/imclassification/README.md new file mode 100644 index 000000000000..4677f289d7e0 --- /dev/null +++ b/contrib/clojure-package/examples/imclassification/README.md @@ -0,0 +1,18 @@ +# imclassification + +This shows off how to do image classification with the module api + +There is an example of the high level training api fit and also how to use multiple cpus/gpus + +To see more examples of how to use different parts of the module api look at the module example + +To run the example you must do + +* `lein install` in the root of the main project directory +* cd into this project directory and do `lein run`. This will execute the cpu version. + +You can control the devices you run on by doing: + +`lein run :cpu 2` - This will run on 2 cpu devices +`lein run :gpu 1` - This will run on 1 gpu device +`lein run :gpu 2` - This will run on 2 gpu devices diff --git a/contrib/clojure-package/examples/imclassification/project.clj b/contrib/clojure-package/examples/imclassification/project.clj new file mode 100644 index 000000000000..5c22d86cfcf7 --- /dev/null +++ b/contrib/clojure-package/examples/imclassification/project.clj @@ -0,0 +1,23 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject imclassification "0.1.0-SNAPSHOT" + :description "Clojure examples for image classification" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]] + :main imclassification.train-mnist + :pedantic? :skip) diff --git a/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj new file mode 100644 index 000000000000..44d3364438df --- /dev/null +++ b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj @@ -0,0 +1,115 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns imclassification.train-mnist + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.kvstore :as kvstore] + [org.apache.clojure-mxnet.kvstore-server :as kvstore-server] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.eval-metric :as eval-metric]) + (:gen-class)) + +(def data-dir "data/") ;; the data directory to store the mnist data +(def batch-size 10) ;; the batch size +(def optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.0})) +(def eval-metric (eval-metric/accuracy)) +(def num-epoch 5) ;; the number of training epochs +(def kvstore "local") ;; the kvstore type +;;; Note to run distributed you might need to complile the engine with an option set +(def role "worker") ;; scheduler/server/worker +(def scheduler-host nil) ;; scheduler hostame/ ip address +(def scheduler-port 0) ;; scheduler port +(def num-workers 1) ;; # of workers +(def num-servers 1) ;; # of servers + + +(def envs (cond-> {"DMLC_ROLE" role} + scheduler-host (merge {"DMLC_PS_ROOT_URI" scheduler-host + "DMLC_PS_ROOT_PORT" (str scheduler-port) + "DMLC_NUM_WORKER" (str num-workers) + "DMLC_NUM_SERVER" (str num-servers)}))) + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "../../scripts/get_mnist_data.sh")) + +;;; Load the MNIST datasets +(defonce train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :input-shape [784] + :batch-size batch-size + :shuffle true + :flat true + :silent false + :seed 10 + :num-parts num-workers + :part-index 0})) + +(defonce test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size batch-size + :flat true + :silent false + :num-parts num-workers + :part-index 0})) + +(defn get-symbol [] + (as-> (sym/variable "data") data + (sym/fully-connected "fc1" {:data data :num-hidden 128}) + (sym/activation "relu1" {:data data :act-type "relu"}) + (sym/fully-connected "fc2" {:data data :num-hidden 64}) + (sym/activation "relu2" {:data data :act-type "relu"}) + (sym/fully-connected "fc3" {:data data :num-hidden 10}) + (sym/softmax-output "softmax" {:data data}))) + +(defn start [devs] + (when scheduler-host + (println "Initing PS enviornments with " envs) + (kvstore-server/init envs)) + + (if (not= "worker" role) + (do + (println "Start KVStoreServer for scheduler and servers") + (kvstore-server/start)) + (do + (println "Starting Training of MNIST ....") + (println "Running with context devices of" devs) + (let [mod (m/module (get-symbol) {:contexts devs})] + (m/fit mod {:train-data train-data + :eval-data test-data + :num-epoch num-epoch + :fit-params (m/fit-params {:kvstore kvstore + :optimizer optimizer + :eval-metric eval-metric})})) + (println "Finish fit")))) + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (start devs))) + +(comment + (start [(context/cpu)]) + ) diff --git a/contrib/clojure-package/examples/module/README.md b/contrib/clojure-package/examples/module/README.md new file mode 100644 index 000000000000..1b08a52f7db3 --- /dev/null +++ b/contrib/clojure-package/examples/module/README.md @@ -0,0 +1,21 @@ +## Instructions + +This shows off how to use the module api. + +There are examples of: + - high level api of training and prediction + - intermediate level api with save and loading from checkpoints + - examples of how to iteratate through the batch and calculate accuracy and predict manually. + +To run the example you must do + +* `lein install` in the root of the main project directory +* cd into this project directory and do `lein run`. This will execute the cpu version. + +You can control the devices you run on by doing: + +`lein run :cpu 2` - This will run on 2 cpu devices +`lein run :gpu 1` - This will run on 1 gpu device +`lein run :gpu 2` - This will run on 2 gpu devices + + diff --git a/contrib/clojure-package/examples/module/project.clj b/contrib/clojure-package/examples/module/project.clj new file mode 100644 index 000000000000..2cd979642893 --- /dev/null +++ b/contrib/clojure-package/examples/module/project.clj @@ -0,0 +1,24 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject module-examples "0.1.0-SNAPSHOT" + :description "Clojure examples for module" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]] + :main mnist-mlp + :pedantic? :skip) + diff --git a/contrib/clojure-package/examples/module/src/mnist_mlp.clj b/contrib/clojure-package/examples/module/src/mnist_mlp.clj new file mode 100644 index 000000000000..039a6ebc6128 --- /dev/null +++ b/contrib/clojure-package/examples/module/src/mnist_mlp.clj @@ -0,0 +1,240 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns mnist-mlp + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.ndarray :as ndarray]) + (:gen-class)) + +(def data-dir "data/") +(def batch-size 10) +(def num-epoch 5) + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "../../scripts/get_mnist_data.sh")) +;; for save checkpoints load checkpoints +(io/make-parents "model/dummy.txt") + +;;; Load the MNIST datasets +(defonce train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :input-shape [784] + :batch-size batch-size + :shuffle true + :flat true + :silent false + :seed 10})) + +(defonce test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size batch-size + :flat true + :silent false})) +(defn get-symbol [] + (as-> (sym/variable "data") data + (sym/fully-connected "fc1" {:data data :num-hidden 128}) + (sym/activation "relu1" {:data data :act-type "relu"}) + (sym/fully-connected "fc2" {:data data :num-hidden 64}) + (sym/activation "relu2" {:data data :act-type "relu"}) + (sym/fully-connected "fc3" {:data data :num-hidden 10}) + (sym/softmax-output "softmax" {:data data}))) + +(defn- print-header [message] + (println "") + (println "=================") + (println (str " " message)) + (println "=================") + (println "")) + +(defn run-intermediate-level-api [& {:keys [devs load-model-epoch]}] + + (let [header "Running Intermediate Level API"] + (print-header (if load-model-epoch (str header " and loading from previous epoch " load-model-epoch) + header))) + + (let [save-prefix "model/mnist-mlp" + mod (if load-model-epoch + (do + (println "Loading from checkpoint of epoch " load-model-epoch) + (m/load-checkpoint {:contexts devs :prefix save-prefix :epoch load-model-epoch})) + (m/module (get-symbol) {:contexts devs})) + metric (eval-metric/accuracy)] + (-> mod + (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})})) + + (doseq [epoch-num (range num-epoch)] + (println "starting epoch " epoch-num) + (mx-io/do-batches + train-data + (fn [batch] + (-> mod + (m/forward batch) + (m/update-metric metric (mx-io/batch-label batch)) + (m/backward) + (m/update)))) + (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric)) + (m/save-checkpoint mod {:prefix save-prefix :epoch epoch-num :save-opt-states true})))) + +(defn run-high-level-api [devs] + (print-header "Running High Level API") + + (let [mod (m/module (get-symbol) {:contexts devs})] + ;;; note only one function for training + (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch}) + + ;;high level predict (just a dummy call but it returns a vector of results + (m/predict mod {:eval-data test-data}) + + ;;;high level score (returs the eval values) + (let [score (m/score mod {:eval-data test-data :eval-metric (eval-metric/accuracy)})] + (println "High level predict score is " score)))) + + +(defn run-predication-and-calc-accuracy-manually [devs] + ;;; Gathers all the predictions at once with `predict-every-batch` + ;;; then cycles thorugh the batches and manually calculates the accuracy stats + + (print-header "Running Predicting and Calcing the Accuracy Manually") + + (let [mod (m/module (get-symbol) {:contexts devs})] + ;;; note only one function for training + (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch}) + (let [preds (m/predict-every-batch mod {:eval-data test-data}) + stats (mx-io/reduce-batches test-data + (fn [r b] + (let [pred-label (->> (ndarray/argmax-channel (first (get preds (:index r)))) + (ndarray/->vec) + (mapv int)) + label (->> (mx-io/batch-label b) + (first) + (ndarray/->vec) + (mapv int)) + acc-sum (apply + (mapv (fn [pl l] (if (= pl l) 1 0)) + pred-label label))] + (-> r + (update :index inc) + (update :acc-cnt (fn [v] (+ v (count pred-label)))) + (update :acc-sum (fn [v] (+ v + (apply + (mapv (fn [pl l] (if (= pl l) 1 0)) + pred-label label)))))))) + {:acc-sum 0 :acc-cnt 0 :index 0})] + (println "Stats: " stats) + (println "Accuracy: " (/ (:acc-sum stats) + (* 1.0 (:acc-cnt stats))))))) + +(defn run-prediction-iterator-api [devs] + ;;Cycles through all the batchs and manually predicts and prints out the accuracy + ;;using `predict-batch` + + (print-header "Running the Prediction Iterator API and Calcing the Accuracy Manually") + + (let [mod (m/module (get-symbol) {:contexts devs})] + ;;; note only one function for training + (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch}) + (mx-io/reduce-batches test-data + (fn [r b] + (let [preds (m/predict-batch mod b) + pred-label (->> (ndarray/argmax-channel (first preds)) + (ndarray/->vec) + (mapv int)) + label (->> (mx-io/batch-label b) + (first) + (ndarray/->vec) + (mapv int)) + acc (/ (apply + (mapv (fn [pl l] (if (= pl l) 1 0)) pred-label label)) + (* 1.0 (count pred-label)))] + (println "Batch " r " acc: " acc) + (inc r)))))) + +(defn run-all [devs] + (run-intermediate-level-api :devs devs) + (run-intermediate-level-api :devs devs :load-model-epoch (dec num-epoch)) + (run-high-level-api devs) + (run-prediction-iterator-api devs) + (run-predication-and-calc-accuracy-manually devs)) + +(defn -main + [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (println "Running Module MNIST example") + (println "Running with context devices of" devs) + (run-all devs))) + + +(comment + + ;;; run all the example functions + (run-all [(context/cpu)]) + + ;;; run for the number of epochs + (run-intermediate-level-api :devs [(context/cpu)]) + ;;=> starting epoch 0 + ;;=> result for epoch 0 is [accuracy 0.8531333] + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0000.params + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0000.states + ;;=> .... + ;;=> starting epoch 4 + ;;=> result for epoch 4 is [accuracy 0.91875] + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0004.params + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0004.states + + + ;; load from the last saved file and run again + (run-intermediate-level-api :devs [(context/cpu)] :load-model-epoch (dec num-epoch)) + ;;=> Loading from checkpoint of epoch 4 + ;;=> starting epoch 0 + ;;=> result for epoch 0 is [accuracy 0.96258336] + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0000.params + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0000.states + ;;=> ... + ;;=> starting epoch 4 + ;;=> result for epoch 4 is [accuracy 0.9819833] + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to model/mnist-mlp-0004.params + ;;=> INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to model/mnist-mlp-0004.states + + (run-high-level-api [(context/cpu)]) + ;;=> ["accuracy" 0.9454] + + (run-prediction-iterator-api [(context/cpu)]) + ;;=> Batch 0 acc: 1.0 + ;;=> Batch 1 acc: 0.9 + ;;=> Batch 2 acc: 1.0 + ;;=> ... + ;;=> Batch 999 acc: 1.0 + + (run-predication-and-calc-accuracy-manually [(context/cpu)]) + ;;=> Stats: {:acc-sum 9494, :acc-cnt 10000, :index 1000} + ;;=> Accuracy: 0.9494 +) + + diff --git a/contrib/clojure-package/examples/multi-label/.gitignore b/contrib/clojure-package/examples/multi-label/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/multi-label/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/multi-label/README.md b/contrib/clojure-package/examples/multi-label/README.md new file mode 100644 index 000000000000..27a8c1ff01ff --- /dev/null +++ b/contrib/clojure-package/examples/multi-label/README.md @@ -0,0 +1,18 @@ +# multi-label + +This is a quick example of doing multi-label classification. +It involves using a proxy to implement the DataIter to make a custom +data iterator for MNIST + +To run +`lein run`. This will execute the cpu version. + +You can control the devices you run on by doing: + +`lein run :cpu` - This will run on 1 cpu device +`lein run :gpu` - This will run on 1 gpu device + +This example only works on 1 device + + + diff --git a/contrib/clojure-package/examples/multi-label/project.clj b/contrib/clojure-package/examples/multi-label/project.clj new file mode 100644 index 000000000000..b67178b25c40 --- /dev/null +++ b/contrib/clojure-package/examples/multi-label/project.clj @@ -0,0 +1,22 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject multi-label "0.1.0-SNAPSHOT" + :description "Example of multi-label classification" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]] + :main multi-label.core) diff --git a/contrib/clojure-package/examples/multi-label/src/multi_label/core.clj b/contrib/clojure-package/examples/multi-label/src/multi_label/core.clj new file mode 100644 index 000000000000..0707d4293f5e --- /dev/null +++ b/contrib/clojure-package/examples/multi-label/src/multi_label/core.clj @@ -0,0 +1,169 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns multi-label.core + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.context :as context]) + (:import (org.apache.mxnet DataIter) + (java.util NoSuchElementException)) + (:gen-class)) + + +(def data-dir "data/") +(def batch-size 100) +(def num-epoch 1) + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "../../scripts/get_mnist_data.sh")) + +;;; Load the MNIST datasets +(defonce train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :input-shape [784] + :batch-size batch-size + :shuffle true + :flat true + :silent false + :seed 10})) + +(defonce test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size batch-size + :flat true + :silent false})) +(defn build-network [] + (let [fc3 (as-> (sym/variable "data") data + (sym/fully-connected "fc1" {:data data :num-hidden 128}) + (sym/activation "relu1" {:data data :act-type "relu"}) + (sym/fully-connected "fc2" {:data data :num-hidden 64}) + (sym/activation "relu2" {:data data :act-type "relu"}) + (sym/fully-connected "fc3" {:data data :num-hidden 10})) + sm1 (sym/softmax-output "softmax1" {:data fc3}) + sm2 (sym/softmax-output "softmax2" {:data fc3})] + (sym/group [sm1 sm2]))) + +;;; provide an override proxy to the DataIter Scala class +(def multi-train-data (let [data-iter train-data] + (proxy [DataIter] [] + (hasNext [] + (mx-io/has-next? data-iter)) + (next [] + (if (mx-io/has-next? data-iter) + (let [batch (mx-io/next data-iter)] + (mx-io/data-batch {:data (util/scala-vector->vec + (.getData data-iter)) + :label (let [label (first + (util/scala-vector->vec (.getLabel data-iter)))] + [label label]) + :index (util/scala-vector->vec + (.getIndex data-iter)) + :pad (.pad batch)})) + (throw (new NoSuchElementException)))) + (reset [] + (mx-io/reset data-iter)) + (batchSize [] + (.batchSize data-iter)) + (getData [] + (.getData data-iter)) + (getLabel [] + (let [label (first (util/scala-vector->vec (.getLabel data-iter)))] (util/vec->indexed-seq [label label]))) + (getIndex [] + (.getIndex data-iter)) + (getPad [] + (.getPad data-iter)) + (provideLabel [] + (let [shape (->> (mx-io/provide-label data-iter) + (first) + (vals) + last)] + (util/list-map + {"softmax1_label" (mx-shape/->shape shape) + "softmax2_label" (mx-shape/->shape shape)}))) + (provideData [] + (.provideData data-iter))))) + +(defn train [devs] + (let [network (build-network) + data-and-labels (->> (into (mx-io/provide-data multi-train-data) + (mx-io/provide-label multi-train-data)) + (mapcat vals) + (apply hash-map)) + [arg-shapes output-shapes aux-shapes] (sym/infer-shape network data-and-labels) + arg-names (sym/list-arguments network) + aux-names (sym/list-auxiliary-states network) + arg-params (zipmap arg-names (mapv #(ndarray/empty %) arg-shapes)) + aux-params (zipmap aux-names (mapv #(ndarray/empty %) aux-shapes)) + metric (eval-metric/custom-metric + (fn [labels preds] + (println "Carin labels " labels) + (println "Carin preds " preds) + (float 0.5)) + "multi-accuracy") + mod (-> (m/module network {:contexts devs}) + (m/bind {:data-shapes (mx-io/provide-data multi-train-data) + :label-shapes (mx-io/provide-label multi-train-data)}) + (m/init-params {:arg-params arg-params :aux-params aux-params}) + (m/init-optimizer))] + (doseq [i (range 1)] + (println "Doing epoch " i) + (let [acc (mx-io/reduce-batches + multi-train-data + (fn [r b] + (let [labels (mx-io/batch-label b) + preds (-> (m/forward mod b) + (m/outputs)) + accs (mapv (fn [p l] + (let [pred-label (->> (ndarray/argmax-channel (first p)) + (ndarray/->vec) + (mapv int)) + label (->> (ndarray/->vec l) + (mapv int))] + (* 1.0 (apply + (mapv (fn [pl l] (if (= pl l) 1 0)) + pred-label label))))) + preds labels)] + (-> mod + (m/backward) + (m/update)) + (-> r + (update :sum #(mapv (fn [o n] (+ o n)) % accs)) + (update :batch-num inc)))) + {:sum [0 0] :batch-num 0})] + (println "Multi-accuracy " acc) + (println "Multi-accuracy "(mapv #(/ % (:batch-num acc)) (:sum acc))))))) + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (println "Training...") + (println "Running with context devices of" devs) + (train devs))) + + +(comment + (train [(context/cpu)])) diff --git a/contrib/clojure-package/examples/neural-style/.gitignore b/contrib/clojure-package/examples/neural-style/.gitignore new file mode 100644 index 000000000000..4ec03eb2c0a3 --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/.gitignore @@ -0,0 +1,13 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ +output/* +input/* diff --git a/contrib/clojure-package/examples/neural-style/README.md b/contrib/clojure-package/examples/neural-style/README.md new file mode 100644 index 000000000000..e05c31552dcd --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/README.md @@ -0,0 +1,24 @@ +# neural-style + +An example of neural style transfer + +## Usage + +use the `download.sh` script to get the params file and the input and output file + +Then use `lein run` + +The output images will be stored in the output directory. Please feel free to play with the params at the top of the file + + +This example only works on 1 device (cpu) right now + +If you are running on AWS you will need to setup X11 for graphics +`sudo apt install xauth x11-apps` + +then relogin in `ssh -X -i creds ubuntu@yourinstance` + + +_Note: This example is not working all the way - it needs some debugging help_ + + diff --git a/contrib/clojure-package/examples/neural-style/download.sh b/contrib/clojure-package/examples/neural-style/download.sh new file mode 100755 index 000000000000..393d03b6163e --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/download.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p model +cd model +wget https://github.com/dmlc/web-data/raw/master/mxnet/neural-style/model/vgg19.params +cd .. + +mkdir -p input +cd input +wget https://github.com/dmlc/web-data/raw/master/mxnet/neural-style/input/IMG_4343.jpg +wget https://github.com/dmlc/web-data/raw/master/mxnet/neural-style/input/starry_night.jpg +cd .. + +mkdir -p output diff --git a/contrib/clojure-package/examples/neural-style/project.clj b/contrib/clojure-package/examples/neural-style/project.clj new file mode 100644 index 000000000000..4daf20f8d094 --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/project.clj @@ -0,0 +1,25 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject neural-style "0.1.0-SNAPSHOT" + :description "Neural Style Transfer with MXNet" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"] + [net.mikera/imagez "0.12.0"] + [thinktopic/think.image "0.4.16"]] + :main neural-style.core +) diff --git a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj new file mode 100644 index 000000000000..07b0a293243a --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj @@ -0,0 +1,264 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns neural-style.core + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.lr-scheduler :as lr-scheduler] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as opt] + [org.apache.clojure-mxnet.random :as random] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.symbol :as sym] + [mikera.image.core :as img] + [mikera.image.filters :as img-filter] + [think.image.pixel :as pixel] + [neural-style.model-vgg-19 :as model-vgg-19]) + (:gen-class)) + + + ;; An Implementation of the paper A Neural Algorithm of Artistic Style + ;;by Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge + +(def content-image "input/IMG_4343.jpg" ) +(def style-image "input/starry_night.jpg") +(def model-path "model/vgg19.params") +(def max-long-edge 600) ;; resize the content image +(def style-weight 1) ;; the weight for the style image +(def content-weight 5) ;; the weight for the content image +(def blur-radius 1) ;; the blur filter radius +(def output-dir "output") +(def lr 10) ;; the learning rate +(def tv-weight 0.01) ;; the magnitude on the tv loss +(def num-epochs 1000) +(def num-channels 3) + +(defn image->ndarray [simg] + (let [h (img/height simg) + w (img/width simg) + pixels (img/get-pixels simg) + ;; normalize the pixels for vgg19 + rgb-pixels (reduce (fn [result pixel] + (let [[rs gs bs] result + [r g b _] (pixel/unpack-pixel pixel)] + [(conj rs (- r 123.68)) + (conj gs (- g 116.779)) + (conj bs (- b 103.939))])) + [[] [] []] + pixels)] + (println "The resized image is size " {:height h :width w}) + (-> rgb-pixels + (flatten) + (ndarray/array [1 num-channels h w])))) + +(defn preprocess-content-image [path short-edge] + (let [simg (img/load-image path) + _ (println "The content image is size " {:height (img/height simg) :width (img/width simg)}) + factor (/ short-edge (img/width simg)) + resized-img (img/resize simg (* (img/width simg) factor) (* (img/height simg) factor) ) + new-height (img/height resized-img) + new-width (img/width resized-img)] + (image->ndarray resized-img))) + +(defn preprocess-style-image [path shape-vec] + (let [[_ _ h w] shape-vec + simg (img/load-image path) + _ (println "The image is size " {:height (img/height simg) :width (img/width simg)}) + resized-img (img/resize simg w h)] + (image->ndarray resized-img))) + +(defn postprocess-image [img] + (let [datas (ndarray/->vec img) + image-shape (mx-shape/->vec (ndarray/shape img)) + spatial-size (* (get image-shape 2) (get image-shape 3)) + [rs gs bs] (doall (partition spatial-size datas)) + pixels (mapv (fn [r g b] + (pixel/pack-pixel + (int (+ r 123.68)) + (int (+ g 116.779)) + (int (+ b 103.939)) + (int 255))) + rs gs bs) + new-image (img/new-image (get image-shape 3) (get image-shape 2)) + _ (img/set-pixels new-image (int-array pixels))] + new-image)) + + +(defn style-gram-symbol [input-size style] + (let [[_ output-shape _] (sym/infer-shape style {:data [1 3 (first input-size) (second input-size)]}) + output-shapes (mx-shape/->vec output-shape) + {:keys [gram-list grad-scale]} (doall (reduce + (fn [result i] + (let [shape (get output-shapes i) + [s0 s1 s2 s3] shape + x (sym/reshape {:data (sym/get style i) :target-shape [s1 (* s2 s3)] }) + ;; use fully connected to quickly do dot(x x^T) + gram (sym/fully-connected {:data x :weight x :no-bias true :num-hidden s1})] + (-> result + (update :gram-list conj gram) + (update :grad-scale conj (* s1 s2 s3 s1))))) + {:gram-list [] :grad-scale []} + (range (count (sym/list-outputs style)))))] + {:gram (sym/group (into [] gram-list)) :g-scale grad-scale})) + +(defn get-loss [gram content] + (let [gram-loss (doall (mapv (fn [i] + (let [gvar (sym/variable (str "target_gram_" i))] + (sym/sum (sym/square (sym/- gvar (sym/get gram i)))))) + (range (count (sym/list-outputs gram))))) + cvar (sym/variable "target_content") + content-loss (sym/sum (sym/square (sym/- cvar content)))] + {:style-loss (sym/group gram-loss) :content-loss content-loss})) + +(defn old-clip [v] + (mapv (fn [a] (cond + (neg? a) 0 + (> a 255) 255 + :else a)) + v)) + +(defn clip [a] + (cond + (neg? a) 0 + (> a 255) 255 + :else a)) + + +(defn save-image [img filename radius blur?] + (let [filtered-image (if blur? + ((img-filter/box-blur blur-radius blur-radius) (postprocess-image img)) + (postprocess-image img))] + (do + ;(img/show filtered-image) ;; Uncomment to have the image display + (img/write filtered-image filename "png")))) + +(defn get-tv-grad-executor [img ctx tv-weight] + (when (pos? tv-weight) + (let [img-shape (mx-shape/->vec (ndarray/shape img)) + n-channel(get img-shape 1) + s-img (sym/variable "img") + s-kernel (sym/variable "kernel") + channels (sym/split {:data s-img :axis 1 :num-outputs n-channel}) + out (sym/concat (doall (mapv (fn [i] + (sym/convolution {:data (sym/get channels i) :weight s-kernel + :num-filter 1 :kernel [3 3] :pad [1 1] :no-bias true :stride [1 1]})) + (range n-channel)))) + kernel (ndarray/* (ndarray/array [0 -1 0 -1 4 -1 0 -1 0] [1 1 3 3] {:ctx ctx}) + 0.8) + out (ndarray/* out tv-weight)] + (sym/bind out ctx {"img" img "kernel" kernel})))) + + +(defn train [devs] + + (let [dev (first devs) + content-np (preprocess-content-image content-image max-long-edge) + content-np-shape (mx-shape/->vec (ndarray/shape content-np)) + style-np (preprocess-style-image style-image content-np-shape) + size [(get content-np-shape 2) (get content-np-shape 3)] + {:keys [style content]} (model-vgg-19/get-symbol) + {:keys [gram g-scale]} (style-gram-symbol size style) + model-executor (model-vgg-19/get-executor gram content model-path size dev) + + _ (ndarray/set (:data model-executor) style-np) + _ (executor/forward (:executor model-executor)) + + style-array (mapv #(ndarray/copy %) (:style model-executor)) + + mode-executor nil + _ (ndarray/set (:data model-executor) content-np) + _ (executor/forward (:executor model-executor)) + content-array (ndarray/copy (:content model-executor)) + + {:keys [style-loss content-loss]} (get-loss gram content) + model-executor (model-vgg-19/get-executor style-loss content-loss model-path size dev) + + grad-array (-> (doall (mapv (fn [i] + (do + (ndarray/set (get (:arg-map model-executor) (str "target_gram_" i)) (get style-array i)) + (ndarray/* (ndarray/ones [1] {:ctx dev}) (/ style-weight (get g-scale i))))) + (range (count style-array)))) + (conj (ndarray/* (ndarray/ones [1] {:ctx dev}) content-weight))) + + _ (ndarray/copy-to content-array (get (:arg-map model-executor) "target_content")) + + ;;;train + + ;;initialize with random noise + img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128) + ;;; img (random/uniform -0.1 0.1 content-np-shape dev) + ;; img content-np + lr-sched (lr-scheduler/factor-scheduler 10 0.9) + + _ (save-image content-np (str output-dir "/input.png") blur-radius false) + _ (save-image style-np (str output-dir "/style.png") blur-radius false) + + optimizer (opt/adam {:learning-rate lr + :wd 0.005 + :lr-scheduler lr-sched}) + optim-state (opt/create-state optimizer 0 img) + + _ (println "Starting training....") + old-img (ndarray/copy-to img dev) + clip-norm (apply * (mx-shape/->vec (ndarray/shape img))) + tv-grad-executor (get-tv-grad-executor img dev tv-weight) + eps 0.0 + e 0 ] + (doseq [i (range 20)] + (ndarray/set (:data model-executor) img) + (-> (:executor model-executor) + (executor/forward) + (executor/backward grad-array)) + + (let [g-norm (ndarray/to-scalar (ndarray/norm (:data-grad model-executor)))] + (if (> g-norm clip-norm) + (ndarray/set (:data-grad model-executor) (ndarray/* (:data-grad model-executor) (/ clip-norm g-norm))))) + + (if tv-grad-executor + (do + (executor/forward tv-grad-executor) + (opt/update optimizer 0 + img + (ndarray/+ (:data-grad model-executor) (first (executor/outputs tv-grad-executor))) + optim-state)) + (opt/update optimizer 0 img (:data-grad model-executor) optim-state)) + + (let [eps (ndarray/to-scalar + (ndarray/div (ndarray/norm (ndarray/- old-img img)) + (ndarray/norm img)))] + (println "Epoch " i "relative change " eps) + (when (zero? (mod i 2)) + (save-image (ndarray/copy img) (str output-dir "/out_" i ".png") blur-radius true))) + + (ndarray/set old-img img)))) + +(defn -main [& args] + ;;; Note this only works on cpu right now + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (println "Running with context devices of" devs) + (train devs))) + + + +(comment + + (train [(context/cpu)]) + + ) diff --git a/contrib/clojure-package/examples/neural-style/src/neural_style/model_vgg_19.clj b/contrib/clojure-package/examples/neural-style/src/neural_style/model_vgg_19.clj new file mode 100644 index 000000000000..5fa11be8851e --- /dev/null +++ b/contrib/clojure-package/examples/neural-style/src/neural_style/model_vgg_19.clj @@ -0,0 +1,99 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns neural-style.model-vgg-19 + (:require [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym])) + +(defn get-symbol [] + (let [data (sym/variable "data") + conv1-1 (sym/convolution "conv1_1" {:data data :num-filter 64 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu1-1 (sym/activation "relu1_1" {:data conv1-1 :act-type "relu"}) + conv1-2 (sym/convolution "conv1-2" {:data relu1-1 :num-filter 64 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu1-2 (sym/activation "relu1_2" {:data conv1-2 :act-type "relu"}) + pool1 (sym/pooling "pool1" {:data relu1-2 :pad [0 0] :kernel [2 2] :stride [2 2] :pool-type "avg"}) + conv2-1 (sym/convolution "conv2_1" {:data pool1 :num-filter 128 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu2-1 (sym/activation "relu2_1" {:data conv2-1 :act-type "relu"}) + conv2-2 (sym/convolution "conv2_2" {:data relu2-1 :num-filter 128 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu2-2 (sym/activation "relu2_2" {:data conv2-2 :act-type "relu"}) + pool2 (sym/pooling "pool2" {:data relu2-2 :pad [0 0] :kernel [2 2] :stride [2 2] :pool-type "avg"}) + conv3-1 (sym/convolution "conv3_1" {:data pool2 :num-filter 256 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu3-1 (sym/activation "relu3_1" {:data conv3-1 :act-type "relu"}) + conv3-2 (sym/convolution "conv3_2" {:data relu3-1 :num-filter 256 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu3-2 (sym/activation "relu3_2" {:data conv3-2 :act-type "relu"}) + conv3-3 (sym/convolution "conv3_3" {:data relu3-2 :num-filter 256 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu3-3 (sym/activation "relu3_3" {:data conv3-3 :act-type "relu"}) + conv3-4 (sym/convolution "conv3_4" {:data relu3-3 :num-filter 256 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu3-4 (sym/activation "relu3_4" {:data conv3-4 :act-type "relu"}) + pool3 (sym/pooling "pool3" {:data relu3-4 :pad [0 0] :kernel [2 2] :stride [2 2] :pool-type "avg"}) + conv4-1 (sym/convolution "conv4_1" {:data pool3 :num-filter 512 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu4-1 (sym/activation "relu4_1" {:data conv4-1 :act-type "relu"}) + conv4-2 (sym/convolution "conv4_2" {:data relu4-1 :num-filter 512 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu4-2 (sym/activation "relu4_2" {:data conv4-2 :act-type "relu"}) + conv4-3 (sym/convolution "conv4_3" {:data relu4-2 :num-filter 512 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu4-3 (sym/activation "relu4_3" {:data conv4-3 :act-type "relu"}) + conv4-4 (sym/convolution "conv4_4" {:data relu4-3 :num-filter 512 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu4-4 (sym/activation "relu4_4" {:data conv4-4 :act-type "relu"}) + pool4 (sym/pooling "pool4" {:data relu4-4 :pad [0 0] :kernel [2 2] :stride [2 2] :pool-type "avg"}) + conv5-1 (sym/convolution "conv5_1" {:data pool4 :num-filter 512 :pad [1 1] :kernel [3 3] :stride [1 1] + :no-bias false :workspace 1024}) + relu5-1 (sym/activation "relu5_1" {:data conv5-1 :act-type "relu"}) + + ;;; style and content layers + style (sym/group [relu1-1 relu2-1 relu3-1 relu4-1 relu5-1]) + content (sym/group [relu1-1])] + {:style style :content content})) + + +(defn get-executor [style content model-path input-size ctx] + (let [out (sym/group [style content]) + ;; make executor + [arg-shapes output-shapes aux-shapes] (sym/infer-shape out {:data [1 3 (first input-size) (second input-size)]}) + arg-names (sym/list-arguments out) + arg-map (zipmap arg-names (map #(ndarray/zeros % {:ctx ctx}) arg-shapes)) + grad-map {"data" (ndarray/copy-to (get arg-map "data") ctx)} + ;; init with pre-training weights + ;;; I'm not sure this is being set properly + pretrained (do (ndarray/load model-path)) + arg-map (into {} (mapv (fn [[k v]] + (let [pretrained-key (str "arg:" k)] + (if (and (get pretrained pretrained-key) (not= "data" k)) + (do (ndarray/set v (get pretrained pretrained-key)) + [k v]) + [k v]))) + arg-map)) + exec (sym/bind out ctx arg-map grad-map) + outs (executor/outputs exec)] + {:executor exec + :data (get arg-map "data") + :data-grad (get grad-map "data") + :style (into [] (butlast outs)) + :content (last outs) + :arg-map arg-map})) diff --git a/contrib/clojure-package/examples/pre-trained-models/.gitignore b/contrib/clojure-package/examples/pre-trained-models/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/pre-trained-models/README.md b/contrib/clojure-package/examples/pre-trained-models/README.md new file mode 100644 index 000000000000..751109f7bb13 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/README.md @@ -0,0 +1,34 @@ +# pre-trained-models + +This shows examples of how to use the pretrained models. MXNet comes with a number of pretrained models +https://mxnet.incubator.apache.org/model_zoo/index.html + + +## Predict Image from pretrained models + +From the example on https://mxnet.incubator.apache.org/tutorials/python/predict_image.html + + +The `predict-image.clj` file loads up the pre-trained resnet-152 model and uses it to predict the classifications from images on the internet + +*To use run download-reset-152.sh to get the model params and json * + + +## Fine Tune from pretrained models + +From the finetune example https://mxnet.incubator.apache.org/faq/finetune.html + +The `fine-tune.clj` file loads up the samller resnet-50 model and adds a fine tune layer to reclassify the caltech iamge set + +*To use run download-resnet-50.sh to get the model params and json and download-caltech.sh to get the pregenerated rec files* + +You can run the fine tune example by doing `lein run` (cpu) + +You can control the devices you run on by doing: + +`lein run :cpu 2` - This will run on 2 cpu devices +`lein run :gpu 1` - This will run on 1 gpu device +`lein run :gpu 2` - This will run on 2 gpu devices + + + diff --git a/contrib/clojure-package/examples/pre-trained-models/download-caltech.sh b/contrib/clojure-package/examples/pre-trained-models/download-caltech.sh new file mode 100755 index 000000000000..8ad8acaffe56 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/download-caltech.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p caltech-256 +cd caltech-256 +wget http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec +wget http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec +cd .. diff --git a/contrib/clojure-package/examples/pre-trained-models/download-resnet-152.sh b/contrib/clojure-package/examples/pre-trained-models/download-resnet-152.sh new file mode 100755 index 000000000000..b3aa7668f751 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/download-resnet-152.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p model +cd model +wget http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152-symbol.json +wget http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152-0000.params +wget http://data.mxnet.io/models/imagenet-11k/synset.txt +cd .. + diff --git a/contrib/clojure-package/examples/pre-trained-models/download-resnet-50.sh b/contrib/clojure-package/examples/pre-trained-models/download-resnet-50.sh new file mode 100755 index 000000000000..3286f51e8e18 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/download-resnet-50.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p model +cd model +wget http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50-symbol.json +wget http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50-0000.params +cd .. + diff --git a/contrib/clojure-package/examples/pre-trained-models/project.clj b/contrib/clojure-package/examples/pre-trained-models/project.clj new file mode 100644 index 000000000000..254f34a98776 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/project.clj @@ -0,0 +1,24 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject pre-trained-models "0.1.0-SNAPSHOT" + :description "Example of using pre-trained models with MXNet" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"] + [net.mikera/imagez "0.12.0"] + [thinktopic/think.image "0.4.16"]] + :main pre-trained-models.fine-tune) diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj new file mode 100644 index 000000000000..a73aa3003633 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj @@ -0,0 +1,132 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns pre-trained-models.fine-tune + (:require [clojure.string :as string] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.initializer :as init] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym]) + (:gen-class)) + +;;; From the finetune example https://mxnet.incubator.apache.org/faq/finetune.html + +;; run download-resnet-50.sh to get the model params and json +;; and download-caltech.sh to get the pregenerated rec files + +(def model-dir "model") +(def batch-size 16) + + +;;; image set is http://www.vision.caltech.edu/Image_Datasets/Caltech101/ +;; Pictures of objects belonging to 101 categories. About 40 to 800 images per category. Most categories have about 50 images + +(def train-iter (mx-io/image-record-iter + {:path-imgrec "caltech-256/caltech-256-60-train.rec" + :data-name "data" + :label-name "softmax_label" + :batch-size batch-size + :data-shape [3 224 224] + :shuffle true + :rand-crop true + :rand-mirror true})) + +(def val-iter (mx-io/image-record-iter + {:path-imgrec "caltech-256/caltech-256-60-val.rec" + :data-name "data" + :label-name "softmax_label" + :batch-size batch-size + :data-shape [3 224 224] + :rand-crop false + :rand-mirror false})) + +(defn get-model [] + (let [mod (m/load-checkpoint {:prefix (str model-dir "/resnet-50") :epoch 0})] + {:msymbol (m/symbol mod) + :arg-params (m/arg-params mod) + :aux-params (m/aux-params mod)})) + +(defn get-fine-tune-model + "msymbol: the pretrained network symbol + arg-params: the argument parameters of the pretrained model + num-classes: the number of classes for the fine-tune datasets + layer-name: the layer name before the last fully-connected layer" + [{:keys [msymbol arg-params num-classes layer-name] + :or {layer-name "flatten0"}}] + (let [all-layers (sym/get-internals msymbol) + net (sym/get all-layers (str layer-name "_output"))] + {:net (as-> net data + (sym/fully-connected "fc1" {:data data :num-hidden num-classes}) + (sym/softmax-output "softmax" {:data data})) + :new-args (->> arg-params + (remove (fn [[k v]] (string/includes? k "fc1"))) + (into {}))})) + +(defn fit [devs msymbol arg-params aux-params] + (let [mod (-> (m/module msymbol {:contexts devs}) + (m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)}) + (m/init-params {:arg-params arg-params :aux-params aux-params + :allow-missing true}))] + (m/fit mod + {:train-data train-iter + :eval-data val-iter + :num-epoch 1 + :fit-params (m/fit-params {:intializer (init/xavier {:rand-type "gaussian" + :factor-type "in" + :magnitude 2}) + :batch-end-callback (callback/speedometer batch-size 10)})}))) + +(defn fine-tune! [devs] + (let [{:keys [msymbol arg-params aux-params] :as model} (get-model) + {:keys [net new-args]} (get-fine-tune-model (merge model {:num-classes 256}))] + (fit devs net new-args arg-params))) + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (println "Running with context devices of" devs) + (fine-tune! devs)) + ) + + +(comment + + (fine-tune! [(context/cpu)]) + +;INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [10] Speed: 3.61 samples/sec Train-accuracy=0.000000 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [20] Speed: 3.49 samples/sec Train-accuracy=0.005952 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [30] Speed: 3.58 samples/sec Train-accuracy=0.012097 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [40] Speed: 3.49 samples/sec Train-accuracy=0.013720 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [50] Speed: 3.51 samples/sec Train-accuracy=0.017157 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [60] Speed: 3.56 samples/sec Train-accuracy=0.017418 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [70] Speed: 3.56 samples/sec Train-accuracy=0.023768 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [80] Speed: 3.10 samples/sec Train-accuracy=0.024691 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [90] Speed: 3.27 samples/sec Train-accuracy=0.028846 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [100] Speed: 3.42 samples/sec Train-accuracy=0.033416 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [110] Speed: 3.46 samples/sec Train-accuracy=0.034910 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [120] Speed: 3.44 samples/sec Train-accuracy=0.040806 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [130] Speed: 3.41 samples/sec Train-accuracy=0.043893 +;; INFO ml.dmlc.mxnet.Callback$Speedometer: Epoch[0] Batch [140] Speed: 3.42 samples/sec Train-accuracy=0.045213 + +) + + diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj new file mode 100644 index 000000000000..ee25e4ce5044 --- /dev/null +++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj @@ -0,0 +1,115 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns pre-trained-models.predict-image + (:require [clojure.java.io :as io] + [clojure.string :as string] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.symbol :as sym] + [mikera.image.core :as img] + [think.image.pixel :as pixel])) + +;; based on https://mxnet.incubator.apache.org/tutorials/python/predict_image.html + +;; run download-reset-152.sh to get the model params and json + +(def model-dir "model") +(def num-channels 3) +(def h 224) +(def w 224) + +(defn download [uri file] + (with-open [in (io/input-stream uri) + out (io/output-stream file)] + (io/copy in out))) + + +(defn get-image [url show?] + (let [fname "test-image.jpg" + _ (download url fname) + image (-> (img/load-image fname) + (img/resize h w)) + pixels (img/get-pixels image) + rgb-pixels (reduce (fn [result pixel] + (let [[rs gs bs] result + [r g b _] (pixel/unpack-pixel pixel)] + [(conj rs r) (conj gs g) (conj bs b)])) + [[] [] []] + pixels)] + (when show? (img/show image)) + (-> rgb-pixels + (flatten) + (ndarray/array [1 num-channels h w])))) + +(defn predict [img-url show?] + (let [mod (m/load-checkpoint {:prefix (str model-dir "/resnet-152") :epoch 0}) + labels (-> (slurp (str model-dir "/synset.txt")) + (string/split #"\n")) + nd-img (get-image img-url show?) + prob (-> mod + (m/bind {:for-training false :data-shapes [{:name "data" :shape [1 num-channels h w]}]}) + (m/forward {:data [nd-img]}) + (m/outputs) + (ffirst)) + prob-with-labels (mapv (fn [p l] {:prob p :label l}) + (ndarray/->vec prob) + labels)] + (->> (sort-by :prob prob-with-labels) + (reverse) + (take 5)))) + +(defn feature-extraction [] + (let [nd-img (get-image "http://animalsbirds.com/wp-content/uploads/2016/07/Animal-Cat-HD-Wallpapers.jpg" false) + mod (-> (m/load-checkpoint {:prefix (str model-dir "/resnet-152") :epoch 0}) + (m/bind {:for-training false :data-shapes [{:name "data" :shape [1 num-channels h w]}]})) + fe-sym (-> (m/symbol mod) + (sym/get-internals) + (sym/get "flatten0_output")) + fe-mod (-> (m/module fe-sym {:label-names nil}) + (m/bind {:for-training false :data-shapes [{:name "data" :shape [1 num-channels h w]}]}) + (m/init-params {:arg-params (m/arg-params mod) :aux-params (m/aux-params mod)}))] + (-> fe-mod + (m/forward {:data [nd-img]}) + (m/outputs) + (ffirst) + (ndarray/shape) + (mx-shape/->vec)))) + +(comment + + (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg") + ;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"} + ;; {:prob 0.04466057, :label "n01323155 kit"} + ;; {:prob 0.029682875, :label "n01318894 pet"} + ;; {:prob 0.028944906, :label "n02122878 tabby, queen"} + ;; {:prob 0.027530408, :label "n01322221 baby"}) + + (predict "http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg" true) + ;; ({:prob 0.44412872, :label "n02110958 pug, pug-dog"} + ;; {:prob 0.093773685, + ;; :label "n13905792 wrinkle, furrow, crease, crinkle, seam, line"} + ;; {:prob 0.02395489, :label "n01318894 pet"} + ;; {:prob 0.023736171, + ;; :label "n02084732 pooch, doggie, doggy, barker, bow-wow"} + ;; {:prob 0.023329297, :label "n02083346 canine, canid"}) + + (feature-extraction) ;=> [1 2048] + + ) + diff --git a/contrib/clojure-package/examples/profiler/.gitignore b/contrib/clojure-package/examples/profiler/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/profiler/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/profiler/README.md b/contrib/clojure-package/examples/profiler/README.md new file mode 100644 index 000000000000..d8a98d35dbe3 --- /dev/null +++ b/contrib/clojure-package/examples/profiler/README.md @@ -0,0 +1,6 @@ +# profiler + +An example of using the profiler. + +To run use `lein run` +A file will be generated in the directory afterwards `profile-matmul-20iter.json` diff --git a/contrib/clojure-package/examples/profiler/project.clj b/contrib/clojure-package/examples/profiler/project.clj new file mode 100644 index 000000000000..ca2fad25137e --- /dev/null +++ b/contrib/clojure-package/examples/profiler/project.clj @@ -0,0 +1,21 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject profiler "0.1.0-SNAPSHOT" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]] + :main profiler.core) diff --git a/contrib/clojure-package/examples/profiler/src/profiler/core.clj b/contrib/clojure-package/examples/profiler/src/profiler/core.clj new file mode 100644 index 000000000000..e366c578c551 --- /dev/null +++ b/contrib/clojure-package/examples/profiler/src/profiler/core.clj @@ -0,0 +1,59 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns profiler.core + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.profiler :as profiler] + [org.apache.clojure-mxnet.random :as random] + [org.apache.clojure-mxnet.symbol :as sym]) + (:gen-class)) + +(def profiler-mode "symbolic") ;; can be symbolic, imperative, api, mem +(def output-path ".") ;; the profile file output directory +(def profiler-name "profile-matmul-20iter.json") +(def iter-num 100) +(def begin-profiling-iter 50) +(def end-profiling-iter 70) +(def gpu? false) + +(defn run [] + (let [shape [4096 4096] + path (str output-path "/" profiler-name) + ctx (if gpu? (context/gpu) (context/cpu)) + kwargs {:filename path + (keyword (str "profile-" profiler-mode)) 1} + C (sym/dot "dot" [(sym/variable "A") (sym/variable "B")]) + a (random/uniform -1.0 1.0 shape {:ctx ctx}) + b (random/uniform -1.0 1.0 shape {:ctx ctx}) + exec (sym/bind C ctx {"A" [a] "B" [b]})] + + (profiler/profiler-set-config kwargs) + (doseq [i (range iter-num)] + (when (= i begin-profiling-iter) + (profiler/profiler-set-state "run")) + (when (= i end-profiling-iter) + (profiler/profiler-set-state "stop")) + (-> exec + (executor/forward) + (executor/outputs) + (first) + (ndarray/wait-to-read))))) + +(defn -main [& args] + (run)) diff --git a/contrib/clojure-package/examples/rnn/.gitignore b/contrib/clojure-package/examples/rnn/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/rnn/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/rnn/README.md b/contrib/clojure-package/examples/rnn/README.md new file mode 100644 index 000000000000..cad3909447b1 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/README.md @@ -0,0 +1,20 @@ +# rnn + + +Demonstration of LSTM RNN trainined using Obamas text + +## Usage + + +run `./get_data.sh to download the training corpus as well as pretrained model. + +Run `lein run` to start training the corpus from scratch for 2 epochs and then +show the result of training after 75 epochs (cpu) + +You can control the devices you run on by doing: + +`lein run :cpu 2` - This will run on 2 cpu devices +`lein run :gpu 1` - This will run on 1 gpu device +`lein run :gpu 2` - This will run on 2 gpu devices + + diff --git a/contrib/clojure-package/examples/rnn/get_data.sh b/contrib/clojure-package/examples/rnn/get_data.sh new file mode 100755 index 000000000000..4e4a2dc3e4a1 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/get_data.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir data +cd data +wget http://data.mxnet.io/mxnet/data/char_lstm.zip +unzip char_lstm.zip +cd .. diff --git a/contrib/clojure-package/examples/rnn/project.clj b/contrib/clojure-package/examples/rnn/project.clj new file mode 100644 index 000000000000..ff00a10fc289 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/project.clj @@ -0,0 +1,22 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject rnn "0.1.0-SNAPSHOT" + :description "RNN example" + :main rnn.train-char-rnn + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]]) diff --git a/contrib/clojure-package/examples/rnn/src/rnn/lstm.clj b/contrib/clojure-package/examples/rnn/src/rnn/lstm.clj new file mode 100644 index 000000000000..ec8e9e8b6f85 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/src/rnn/lstm.clj @@ -0,0 +1,192 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns rnn.lstm + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym])) + +(defn lstm-param [i2h-weight i2h-bias + h2h-weight h2h-bias] + {:i2h-weight i2h-weight :i2h-bias i2h-bias + :h2h-weight h2h-weight :h2h-bias h2h-bias}) + +(defn lstm-state [c h] + {:c c :h h}) + +(defn lstm [num-hidden in-data prev-state param seq-idx layer-idx dropout] + (let [in-dataa (if (pos? dropout) + (sym/dropout {:data in-data :p dropout}) + in-data) + i2h (sym/fully-connected (str "t" seq-idx "_l" layer-idx "_i2h") + {:data in-dataa :weight (:i2h-weight param) + :bias (:i2h-bias param) :num-hidden (* num-hidden 4)}) + h2h (sym/fully-connected (str "t" seq-idx "_l" layer-idx "_h2h") + {:data (:h prev-state) :weight (:h2h-weight param) + :bias (:h2h-bias param) :num-hidden (* num-hidden 4)}) + gates (sym/+ i2h h2h) + slice-gates (sym/slice-channel (str "t" seq-idx "_l" layer-idx "_slice") + {:data gates :num-outputs 4}) + in-gate (sym/activation {:data (sym/get slice-gates 0) :act-type "sigmoid"}) + in-transform (sym/activation {:data (sym/get slice-gates 1) :act-type "tanh"}) + forget-gate (sym/activation {:data (sym/get slice-gates 2) :act-type "sigmoid"}) + out-gate (sym/activation {:data (sym/get slice-gates 3) :act-type "sigmoid"}) + next-c (sym/+ (sym/* forget-gate (:c prev-state)) + (sym/* in-gate in-transform)) + next-h (sym/* out-gate (sym/activation {:data next-c :act-type "tanh"}))] + (lstm-state next-c next-h))) + +(defn lstm-unroll [num-lstm-layer seq-len input-size num-hidden num-embed num-label dropout] + (let [embed-weight (sym/variable "embed_weight") + cls-weight (sym/variable "cls_weight") + cls-bias (sym/variable "cls_bias") + param-cells (mapv (fn [i] + (lstm-param (sym/variable (str "l" i "_i2h_weight")) + (sym/variable (str "l" i "_i2h_bias")) + (sym/variable (str "l" i "_h2h_weight")) + (sym/variable (str "l" i "_h2h_bias")))) + (range 0 num-lstm-layer)) + last-states (mapv (fn [i] + (lstm-state (sym/variable (str "l" i "_init_c_beta")) + (sym/variable (str "l" i "_init_h_beta")))) + (range 0 num-lstm-layer)) + ;; embedding layer + data (sym/variable "data") + label (sym/variable "softmax_label") + embed (sym/embedding "embed" {:data data :input-dim input-size :weight embed-weight + :output-dim num-embed}) + wordvec (sym/slice-channel {:data embed :num-outputs seq-len :squeeze-axis 1}) + dp-ratio 0 + ;; stack lstm + hidden-all (doall (for [seq-idx (range seq-len)] + (let [hidden (:h (last (loop [i 0 + hidden (sym/get wordvec seq-idx) + next-states []] + (if (= i num-lstm-layer) + next-states + (let [dp-ratio (if (zero? i) 0 dropout) + next-state (lstm num-hidden + hidden + (get last-states i) + (get param-cells i) + seq-idx + i + dp-ratio)] + (recur (inc i) + (:h next-state) + (conj next-states next-state)))))))] + (if (pos? dropout) + (sym/dropout {:data hidden :p dropout}) + hidden)))) + hidden-concat (sym/concat "concat" nil hidden-all {:dim 0}) + pred (sym/fully-connected "pred" {:data hidden-concat :num-hidden num-label + :weight cls-weight :bias cls-bias}) + label (sym/transpose {:data label}) + label (sym/reshape {:data label :target-shape [0]}) + sm (sym/softmax-output "softmax" {:data pred :label label})] + sm)) + +(defn lstm-inference-symbol [num-lstm-layer input-size num-hidden + num-embed num-label dropout] + (let [seq-idx 0 + embed-weight (sym/variable "embed_weight") + cls-weight (sym/variable "cls_weight") + cls-bias (sym/variable "cls_bias") + param-cells (mapv (fn [i] + (lstm-param (sym/variable (str "l" i "_i2h_weight")) + (sym/variable (str "l" i "_i2h_bias")) + (sym/variable (str "l" i "_h2h_weight")) + (sym/variable (str "l" i "_h2h_bias")))) + (range 0 num-lstm-layer)) + last-states (mapv (fn [i] + (lstm-state (sym/variable (str "l" i "_init_c_beta")) + (sym/variable (str "l" i "_init_h_beta")))) + (range 0 num-lstm-layer)) + data (sym/variable "data") + dp-ratio 0 + ;; stack lstm + next-states (loop [i 0 + hidden (sym/embedding "embed" {:data data :input-dim input-size :weight embed-weight :output-dim num-embed}) + next-states []] + (if (= i num-lstm-layer) + next-states + (let [dp-ratio (if (zero? i) 0 dropout) + next-state (lstm num-hidden + hidden + (get last-states i) + (get param-cells i) + seq-idx + i + dp-ratio)] + (recur (inc i) + (:h next-state) + (conj next-states next-state))))) + ;;; decoder + hidden (:h (last next-states)) + hidden (if (pos? dropout) (sym/dropout {:data hidden :p dropout}) hidden) + fc (sym/fully-connected "pred" {:data hidden :num-hidden num-label + :weight cls-weight :bias cls-bias}) + sm (sym/softmax-output "softmax" {:data fc}) + outs (into [sm] (mapcat (fn [next-s] (vals next-s)) next-states))] + (sym/group outs))) + +(defn lstm-inference-model [{:keys [num-lstm-layer input-size num-hidden + num-embed num-label arg-params + ctx dropout] + :or {ctx (context/cpu) + dropout 0.0}}] + + (let [lstm-sym (lstm-inference-symbol num-lstm-layer + input-size + num-hidden + num-embed + num-label + dropout) + batch-size 1 + init-c (into {} (map (fn [l] + {(str "l" l "_init_c_beta") [batch-size num-hidden]}) + (range num-lstm-layer))) + init-h (into {} (map (fn [l] + {(str "l" l "_init_h_beta") [batch-size num-hidden]})) + (range num-lstm-layer)) + data-shape {"data" [batch-size]} + input-shape (merge init-c init-h data-shape) + exec (sym/simple-bind lstm-sym ctx input-shape) + exec-arg-map (executor/arg-map exec) + states-map (zipmap (mapcat (fn [i] [(str "l" i "_init_c_beta") + (str "l" i "_init_h_beta")]) + (range num-lstm-layer)) + (rest (executor/outputs exec)))] + (doseq [[k v] arg-params] + (if-let [target-v (get exec-arg-map k)] + (when (and (not (get input-shape k)) + (not= "softmax_label" k)) + (ndarray/copy-to v target-v)))) + {:exec exec + :states-map states-map})) + +(defn forward [{:keys [exec states-map] :as lstm-model} input-data new-seq] + (when new-seq + (doseq [[k v] states-map] + (ndarray/set (get (executor/arg-map exec) k) 0))) + (do + (ndarray/copy-to input-data (get (executor/arg-map exec) "data")) + (executor/forward exec) + (doseq [[k v] states-map] + (ndarray/copy-to v (get (executor/arg-map exec) k))) + (first (executor/outputs exec)))) diff --git a/contrib/clojure-package/examples/rnn/src/rnn/test_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/test_char_rnn.clj new file mode 100644 index 000000000000..35e1a18c05ce --- /dev/null +++ b/contrib/clojure-package/examples/rnn/src/rnn/test_char_rnn.clj @@ -0,0 +1,79 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns rnn.test-char-rnn + (:require [clojure.string :as string] + [rnn.util :as util] + [rnn.lstm :as lstm] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(def data-path "data/obama.txt") +(def model-prefix) +(def start-sentence "The joke ") +(def num-hidden 512) ;; hidden unit in LSTM cell +(def num-embed 256) ;; the embedding dim (a char is mapped to 256 dim) +(def num-lstm-layer 3) ;; number of lstm layers + +(def vocab (util/build-vocab data-path)) + +(defn rnn-test [model-prefix epoch-num seq-length random?] + (let [trained-mod (m/load-checkpoint {:prefix model-prefix :epoch epoch-num}) + trained-arg-params (m/arg-params trained-mod) + model (lstm/lstm-inference-model {:num-lstm-layer 3 + :input-size (inc (count vocab)) + :num-label (inc (count vocab)) + :num-hidden num-hidden + :num-embed num-embed + :arg-params trained-arg-params}) + input-ndarray (ndarray/zeros [1]) + revert-vocab (util/make-revert-vocab vocab) + fix-dict (into [""] + (mapv #(str (get revert-vocab %)) + (sort (vals vocab)))) + random-sample random? ;; use this to do random sample or max prob + ignore-length (count start-sentence)] + (println "Starter sentence: " start-sentence) + (println "===") + (loop [i 0 + new-sentence true + output start-sentence] + (if (= seq-length i) + output + (do + (if (<= i (dec ignore-length)) + (util/make-input (get start-sentence i) vocab input-ndarray) + (util/make-input (last output) vocab input-ndarray)) + (let [prob (ndarray/->vec (lstm/forward model input-ndarray new-sentence)) + next-char (util/make-output prob fix-dict random-sample)] + (recur (inc i) + (if (= "" next-char) true false) + (if (< i (dec ignore-length)) + output + (str output next-char))))))))) + + +(comment + + (rnn-test "data/obama" 75 200 false) + ;=>"The joke that we can start by the challenges of the American people. The American people have been talking about how to compete with the streets of San Antonio who the courage to come together as one " + + (rnn-test "data/obama" 75 200 true) + ;=>"The joke before them prepared for five years ago, we only hear a chance to lose our efforts and they made striggling procedural deficit at the city between a politics in the efforts on the Edmund Pett" + ) diff --git a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj new file mode 100644 index 000000000000..d2467609ad16 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj @@ -0,0 +1,180 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns rnn.train-char-rnn + (:require [clojure.string :as string] + [rnn.util :as util] + [rnn.lstm :as lstm] + [rnn.test-char-rnn :as test-rnn] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.initializer :as init] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.module :as m]) + (:gen-class)) + +;;https://github.com/apache/incubator-mxnet/blob/master/example/rnn/old/char-rnn.ipynb + +;; batch size for training +(def batch-size 32) +;; we can support various length input +;; for this problem, we cut each input sentence to length of 129 +;; so we only need a fixed lenght bucket +(def buckets [129]) +;;hidden unit in LSTM cell +(def num-hidden 512) +;; embedding dim which is map a char to a 256 dim vector +(def num-embed 256) +;; number of lstm layer +(def num-lstm-layer 3) +;; we will show a quick demo in 2 epoch and we will see the result +;; by training 75 epoch +(def num-epoch 75) +;; learning rate +(def learning-rate 0.01) +;; we will use pure sgd without momentum +(def momentum 0.0) + +(def ctx (context/cpu)) ;; change to gpu if desired +(def data-path "data/obama.txt") +(def vocab (util/build-vocab data-path)) + +;; generate the symbol for a length +(defn sym-gen [seq-len] + (lstm/lstm-unroll num-lstm-layer seq-len (inc (count vocab)) + num-hidden num-embed (inc (count vocab)) 0.2)) + +;;; in the case of this fixed bucketing that only uses one bucket size - it is the equivalent of padpadding all sentences to a fixed length. +;; we are going to use ndarray-iter for this +;; converting the bucketing-iter over to use is todo. We could either push for the example Scala one to be included in the base package and interop with that (which would be nice for other rnn needs too) or hand convert it over ourselves + + +(defn build-training-data [path] + (let [content (slurp path) + sentences (string/split content #"\n") + max-length (first buckets) + padding-int 0] + (doall (for [sentence sentences] + (let [ids (mapv #(get vocab %) sentence)] + (if (>= (count ids) max-length) + (into [] (take max-length ids)) + (into ids (repeat (- max-length (count ids)) 0)))))))) + +(defn build-labels [train-data] + ;; want to learn the next char some rotate by 1 + (doall (mapv (fn [sent-data] (conj (into [] (rest sent-data)) 0)) + train-data))) + +(defn data-desc->map [data-desc] + (->> data-desc + (map vals) + (first) + (apply hash-map))) + + +(defn train [devs] + (let [;; initialize the states for the lstm + init-c (into {} (map (fn [l] + {(str "l" l "_init_c_beta") [batch-size num-hidden]}) + (range num-lstm-layer))) + init-h (into {} (map (fn [l] + {(str "l" l "_init_h_beta") [batch-size num-hidden]})) + (range num-lstm-layer)) + init-states (merge init-c init-h) + train-data (build-training-data data-path) + labels (build-labels train-data) + sent-len (first buckets) + train-iter (mx-io/ndarray-iter [(ndarray/array (flatten train-data) + [(count train-data) sent-len])] + {:label [(ndarray/array (flatten labels) + [(count labels) sent-len])] + :label-name "softmax_label" + :data-batch-size batch-size + :last-batch-handle "pad"}) + data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter)) + (data-desc->map (mx-io/provide-label train-iter)) + init-states) + init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states) + rnn-sym (sym-gen (first buckets)) + + rnn-mod (-> (m/module rnn-sym {:contexts devs}) + (m/bind {:data-shapes (into (mx-io/provide-data train-iter) + (mapv (fn [[k v]] {:name k :shape v}) init-states)) + :label-shapes (mx-io/provide-label train-iter)}) + (m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})}) + (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})})) + metric (eval-metric/custom-metric + (fn [label pred] + (let [labels (ndarray/->vec (ndarray/transpose label)) + pred-shape (ndarray/shape-vec pred) + size (apply * (ndarray/shape-vec label)) + preds (mapv #(into [] %) (doall + (partition (last pred-shape) (ndarray/->vec pred)))) + results (map-indexed + (fn [i l] + (get-in preds [i (int l)])) + labels) + result (->> results + (mapv #(Math/max (float 1e-10) (float %))) + (mapv #(Math/log %)) + (mapv #(* -1.0 %)) + (apply +))] + (float (Math/exp (/ result (count labels))))) + ) + + "perplexity")] + + ;; Train for 2 epochs and then show the results of 75 + (doseq [epoch-num (range 2)] + (println "Doing epoch " epoch-num) + (mx-io/reduce-batches + train-iter + (fn [batch-num batch] + (let [batch (mx-io/next train-iter)] + (-> rnn-mod + (m/forward (mx-io/data-batch {:data (into (mx-io/batch-data batch) init-states-data) + :label (mx-io/batch-label batch)})) + (m/update-metric metric (mx-io/batch-label batch)) + (m/backward) + (m/update)) + (when (zero? (mod batch-num 10)) + (println "Eval metric for batch-num " batch-num " is " (eval-metric/get metric))) + (inc batch-num)))) + (println "Finished epoch " epoch-num) + #_(println "Eval-metric " (eval-metric/get-and-reset metric)) + (m/save-checkpoint rnn-mod {:prefix "train-obama" :epoch epoch-num}) + (println "Testing with random 200 chars ") + (println "=====") + (println (test-rnn/rnn-test "train-obama" epoch-num 200 true)) + (println "=====")) + + (println "Showing the result after 75 epochs (pre-trained)") + (println (test-rnn/rnn-test "data/obama" 75 200 true)) + (println "====="))) + + +(defn -main [& args] + (let [[dev dev-num] args + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range (Integer/parseInt (or dev-num "1")))) + (mapv #(context/cpu %) (range (Integer/parseInt (or dev-num "1")))))] + (train devs))) diff --git a/contrib/clojure-package/examples/rnn/src/rnn/util.clj b/contrib/clojure-package/examples/rnn/src/rnn/util.clj new file mode 100644 index 000000000000..27e2132cf916 --- /dev/null +++ b/contrib/clojure-package/examples/rnn/src/rnn/util.clj @@ -0,0 +1,75 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns rnn.util + (:require [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(defn build-vocab [path] + (let [content (slurp path) + vocab-map (reduce (fn [{:keys [vocab idx] :as result} c] + (if (get vocab c) + result + (-> result + (update :vocab assoc c (inc idx)) + (update :idx inc)))) + {:vocab {} :idx 0} ;; 0 is used for padding + content)] + (:vocab vocab-map))) + +(defn make-revert-vocab [vmap] + (into {} (map (fn [[k v]] [v k]) vmap))) + +(defn make-input [char vocab arr] + (let [idx (get vocab char) + tmp (ndarray/zeros [1])] + (do + (ndarray/set tmp idx) + (ndarray/set arr tmp)))) + +(defn cdf [weights] + (let [total (* 1.0 (apply + weights)) + csums (reduce (fn [cumsum w] (conj cumsum (+ (or (last cumsum) 0) w)) ) [] weights)] + (mapv #(/ % total) csums))) + + +(defn choice [population weights] + (assert (= (count population) (count weights))) + (let [cdf-vals (cdf weights) + x (rand) + idx (-> (partition-by (fn [v] (>= v x)) cdf-vals) + first + count)] + (get population idx))) + +;; we can use random output of fixed-output by choosing the largest probability +(defn make-output [prob fix-dict sample] + (let [temperature 1.0 + char (if sample + (let [scale-prob (mapv (fn [x] (if (< x 1e-6) + 1e-6 + (if (> x (- 1 1e-6)) + (- 1 1e-6) + x))) prob) + rescale (mapv (fn [x] (Math/exp (/ (Math/log x) temperature))) scale-prob) + sum (apply + rescale) + rescale (map (fn [x] (/ x sum)) rescale)] + (choice fix-dict rescale)) + (->> (zipmap prob fix-dict) + (sort-by max) + (vals) + last))] + char)) diff --git a/contrib/clojure-package/examples/scripts/get_cifar_data.sh b/contrib/clojure-package/examples/scripts/get_cifar_data.sh new file mode 100755 index 000000000000..372c7bb5781e --- /dev/null +++ b/contrib/clojure-package/examples/scripts/get_cifar_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +set -evx + +if [ ! -z "$MXNET_DATA_DIR" ]; then + data_path="$MXNET_DATA_DIR" +else + data_path="./data" +fi + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +cifar_data_path="$data_path/cifar10.zip" +if [ ! -f "$cifar_data_path" ]; then + wget http://data.mxnet.io/mxnet/data/cifar10.zip -P $data_path + cd $data_path + unzip -u cifar10.zip +fi diff --git a/contrib/clojure-package/examples/scripts/get_mnist_data.sh b/contrib/clojure-package/examples/scripts/get_mnist_data.sh new file mode 100755 index 000000000000..6f32b85f480b --- /dev/null +++ b/contrib/clojure-package/examples/scripts/get_mnist_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +set -evx + +if [ ! -z "$MXNET_DATA_DIR" ]; then + data_path="$MXNET_DATA_DIR" +else + data_path="./data" +fi + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +mnist_data_path="$data_path/mnist.zip" +if [ ! -f "$mnist_data_path" ]; then + wget http://data.mxnet.io/mxnet/data/mnist.zip -P $data_path + cd $data_path + unzip -u mnist.zip +fi diff --git a/contrib/clojure-package/examples/tutorial/.gitignore b/contrib/clojure-package/examples/tutorial/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/tutorial/README.md b/contrib/clojure-package/examples/tutorial/README.md new file mode 100644 index 000000000000..cd19eb232760 --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/README.md @@ -0,0 +1,5 @@ +# tutorial + +Tutorials are based on the Scala api examples here https://mxnet.incubator.apache.org/api/scala/ndarray.html + +Start with ndarray then move onto symbol and module diff --git a/contrib/clojure-package/examples/tutorial/project.clj b/contrib/clojure-package/examples/tutorial/project.clj new file mode 100644 index 000000000000..027c1d3f3954 --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/project.clj @@ -0,0 +1,21 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject tutorial "0.1.0-SNAPSHOT" + :description "MXNET tutorials" + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]]) diff --git a/contrib/clojure-package/examples/tutorial/src/tutorial/introduction.clj b/contrib/clojure-package/examples/tutorial/src/tutorial/introduction.clj new file mode 100644 index 000000000000..094553fc5bdf --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/src/tutorial/introduction.clj @@ -0,0 +1,34 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns tutorial.introduction + (:require [org.apache.clojure-mxnet.ndarray :as ndarray])) + +;; MXNet supports the Clojure programming language. The MXNet Clojure package brings flexible and efficient GPU computing and state-of-art deep learning to Clojure. It enables you to write seamless tensor/matrix computation with multiple GPUs in Clojure. It also lets you construct and customize the state-of-art deep learning models in Clojure, and apply them to tasks, such as image classification and data science challenges. + +;; You can perform tensor or matrix computation in pure Clojure: + +(def arr (ndarray/ones [2 3])) + +arr ;=> #object[ml.dmlc.mxnet.NDArray 0x482401ab "ml.dmlc.mxnet.NDArray@d8902656"] + +(ndarray/shape-vec arr) ;=> [2 3] + +(-> (ndarray/* arr 2) + (ndarray/->vec)) ;=> [2.0 2.0 2.0 2.0 2.0 2.0] + +(ndarray/shape-vec (ndarray/* arr 2)) ;=> [2 3] diff --git a/contrib/clojure-package/examples/tutorial/src/tutorial/kvstore.clj b/contrib/clojure-package/examples/tutorial/src/tutorial/kvstore.clj new file mode 100644 index 000000000000..5780ac25868e --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/src/tutorial/kvstore.clj @@ -0,0 +1,74 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns tutorial.kvstore + (:require [org.apache.clojure-mxnet.kvstore :as kvstore] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.context :as context])) + +;;Basic Push and Pull +;;Provides basic operation over multiple devices (GPUs or CPUs) on a single device. + +;; Initialization +;; Letā€™s consider a simple example. It initializes a (int, NDArray) pair into the store, and then pulls the value out. + +(def kv (kvstore/create "local")) ;; create a local kvstore +(def shape [2 3]) +;;; init the kvstore with a vector of keys (strings) and ndarrays +(kvstore/init kv ["3"] [(ndarray/* (ndarray/ones shape) 2)]) +(def a (ndarray/zeros shape)) +(kvstore/pull kv ["3"] [a]) +(ndarray/->vec a) ;=> [2.0 2.0 2.0 2.0 2.0 2.0] + + +;;Push, Aggregation, and Updater +;;For any key thatā€™s been initialized, you can push a new value with the same shape to the key, as follows: + +(kvstore/push kv ["3"] [(ndarray/* (ndarray/ones shape) 8)]) +(kvstore/pull kv ["3"] [a]) +(ndarray/->vec a);=>[8.0 8.0 8.0 8.0 8.0 8.0] + +;;The data that you want to push can be stored on any device. Furthermore, you can push multiple values into the same key, where KVStore first sums all of these values, and then pushes the aggregated value, as follows: + +;; using multiple cpus instead of gpus +(def cpus [(context/cpu 0) (context/cpu 1) (context/cpu 2)]) +(def b [(ndarray/ones shape {:ctx (nth cpus 0)}) + (ndarray/ones shape {:ctx (nth cpus 1)}) + (ndarray/ones shape {:ctx (nth cpus 2)}) ]) +(kvstore/push kv ["3" "3" "3"] b) +(kvstore/pull kv "3" a) +(ndarray/->vec a) ;=> [3.0 3.0 3.0 3.0 3.0 3.0] + + +;;Pull +;;Youā€™ve already seen how to pull a single key-value pair. Similar to the way that you use the push command, you can pull the value into several devices with a single call. +(def b [(ndarray/ones shape {:ctx (context/cpu 0)}) + (ndarray/ones shape {:ctx (context/cpu 1)}) ]) +(kvstore/pull kv ["3" "3"] b) +(map ndarray/->vec b) ;=> ([3.0 3.0 3.0 3.0 3.0 3.0] [3.0 3.0 3.0 3.0 3.0 3.0]) + +;;List Key-Value Pairs +;;All of the operations that weā€™ve discussed so far are performed on a single key. KVStore also provides the interface for generating a list of key-value pairs. For a single device, use the following: + +(def ks ["5" "7" "9"]) +(kvstore/init kv ks [(ndarray/ones shape) (ndarray/ones shape) (ndarray/ones shape)]) +(kvstore/push kv ks [(ndarray/ones shape) (ndarray/ones shape) (ndarray/ones shape)]) +(def b [(ndarray/zeros shape) (ndarray/zeros shape)(ndarray/zeros shape)]) +(kvstore/pull kv ks b) +(map ndarray/->vec b);=> ([1.0 1.0 1.0 1.0 1.0 1.0] [1.0 1.0 1.0 1.0 1.0 1.0] [1.0 1.0 1.0 1.0 1.0 1.0]) + + diff --git a/contrib/clojure-package/examples/tutorial/src/tutorial/module.clj b/contrib/clojure-package/examples/tutorial/src/tutorial/module.clj new file mode 100644 index 000000000000..ad5ff133a5b0 --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/src/tutorial/module.clj @@ -0,0 +1,219 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns tutorial.module + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(def data-dir "data/") + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "../../scripts/get_mnist_data.sh")) + +;;; Load the MNIST datasets +(def train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :input-shape [784] + :batch-size 10 + :shuffle true + :flat true + :silent false + :seed 10})) + +(def test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size 10 + :flat true + :silent false})) + +;; The module API provides an intermediate and high-level interface for performing computation with neural networks in MXNet. Module wraps a Symbol and one or more Executors. It has both a high level and intermediate level api + +;; Preparing a module for Computation + +;; construct a module + +(let [data (sym/variable "data") + fc1 (sym/fully-connected "fc1" {:data data :num-hidden 128}) + act1 (sym/activation "relu1" {:data fc1 :act-type "relu"}) + fc2 (sym/fully-connected "fc2" {:data act1 :num-hidden 64}) + act2 (sym/activation "relu2" {:data fc2 :act-type "relu"}) + fc3 (sym/fully-connected "fc3" {:data act2 :num-hidden 10}) + out (sym/softmax-output "softmax" {:data fc3})] + out) ;=> #object[ml.dmlc.mxnet.Symbol 0x2c7b036b "ml.dmlc.mxnet.Symbol@2c7b036b"] + +;; You can also use as-> for easier threading + + +(def out (as-> (sym/variable "data") data + (sym/fully-connected "fc1" {:data data :num-hidden 128}) + (sym/activation "relu1" {:data data :act-type "relu"}) + (sym/fully-connected "fc2" {:data data :num-hidden 64}) + (sym/activation "relu2" {:data data :act-type "relu"}) + (sym/fully-connected "fc3" {:data data :num-hidden 10}) + (sym/softmax-output "softmax" {:data data}))) +;=> #'tutorial.module/out + + +;; By default, context is the CPU. If you need data parallelization, you can specify a GPU context or an array of GPU contexts. +;; like this (m/module out {:contexts [(context/gpu)]}) + +;; Before you can compute with a module, you need to call `bind` to allocate the device memory and `initParams` or `set-params` to initialize the parameters. If you simply want to fit a module, you donā€™t need to call `bind` and `init-params` explicitly, because the `fit` function automatically calls them if they are needed. + +(let [mod (m/module out)] + (-> mod + (m/bind {:data-shapes (mx-io/provide-data train-data) + :label-shapes (mx-io/provide-label train-data)}) + (m/init-params))) + +;; Now you can compute with the module using functions like `forward`, `backward`, etc. + + +;; Training, Predicting, and Evaluating + +;;Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the `fit` function with some DataIters: + +(def mod (m/fit (m/module out) {:train-data train-data :eval-data test-data :num-epoch 1})) +;; INFO ml.dmlc.mxnet.module.BaseModule: Epoch[0] Train-accuracy=0.12521666 +;; INFO ml.dmlc.mxnet.module.BaseModule: Epoch[0] Time cost=7863 +;; INFO ml.dmlc.mxnet.module.BaseModule: Epoch[0] Validation-accuracy=0.2227 + + +;; You can pass in batch-end callbacks using batch-end-callback and epoch-end callbacks using epoch-end-callback in the `fit-params`. You can also set parameters using functions like in the fit-params like optimizer and eval-metric. To learn more about the fit-params, see the fit-param function options. To predict with a module, call `predict` with a DataIter: + +(def results (m/predict mod {:eval-data test-data})) +(first results) ;=> #object[ml.dmlc.mxnet.NDArray 0x270236e5 "ml.dmlc.mxnet.NDArray@9180e594"] + +(first (ndarray/->vec (first results))) ;=> 0.099454574 + +;;The module collects and returns all of the prediction results. For more details about the format of the return values, see the documentation for the `predict` function. + +;;When prediction results might be too large to fit in memory, use the `predict-every-batch` API + +(let [preds (m/predict-every-batch mod {:eval-data test-data})] + (mx-io/reduce-batches test-data + (fn [i batch] + (println (str "pred is " (first (get preds i)))) + (println (str "label is " (mx-io/batch-label batch))) + ;;; do something + (inc i)))) + +;;If you need to evaluate on a test set and donā€™t need the prediction output, call the `score` function with a DataIter and an EvalMetric: + +(m/score mod {:eval-data test-data :eval-metric (eval-metric/accuracy)}) ;=>["accuracy" 0.2227] + +;;This runs predictions on each batch in the provided DataIter and computes the evaluation score using the provided EvalMetric. The evaluation results are stored in metric so that you can query later. + +;;Saving and Loading Module Parameters + +;;To save the module parameters in each training epoch, use a `checkpoint` function + + +(let [save-prefix "my-model"] + (doseq [epoch-num (range 3)] + (mx-io/do-batches train-data (fn [batch + ;; do something + ])) + (m/save-checkpoint mod {:prefix save-prefix :epoch epoch-num :save-opt-states true}))) + +;; INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to my-model-0000.params +;; INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to my-model-0000.states +;; INFO ml.dmlc.mxnet.module.Module: Saved checkpoint to my-model-0001.params +;; INFO ml.dmlc.mxnet.module.Module: Saved optimizer state to my-model-0001.states + +;;To load the saved module parameters, call the `load-checkpoint` function: + +(def new-mod (m/load-checkpoint {:prefix "my-model" :epoch 1 :load-optimizer-states true})) +;=> #object[ml.dmlc.mxnet.module.Module 0x352c8590 "ml.dmlc.mxnet.module.Module@352c8590"] + +;;To initialize parameters, Bind the symbols to construct executors first with bind function. Then, initialize the parameters and auxiliary states by calling `init-params` function. + +(-> new-mod + (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)}) + (m/init-params)) + +;;To get current parameters, use `params` + +(let [[arg-params aux-params] (m/params new-mod)] + {:arg-params arg-params + :aux-params aux-params}) + +;; {:arg-params +;; {"fc3_bias" +;; #object[ml.dmlc.mxnet.NDArray 0x4fcda4a0 "ml.dmlc.mxnet.NDArray@70276e89"], +;; "fc2_weight" +;; #object[ml.dmlc.mxnet.NDArray 0x33651972 "ml.dmlc.mxnet.NDArray@b2a396eb"], +;; "fc1_bias" +;; #object[ml.dmlc.mxnet.NDArray 0x3ad02326 "ml.dmlc.mxnet.NDArray@b4110d31"], +;; "fc3_weight" +;; #object[ml.dmlc.mxnet.NDArray 0x4c088d9b "ml.dmlc.mxnet.NDArray@19399ebd"], +;; "fc2_bias" +;; #object[ml.dmlc.mxnet.NDArray 0x3cca519d "ml.dmlc.mxnet.NDArray@61012c"], +;; "fc1_weight" +;; #object[ml.dmlc.mxnet.NDArray 0xea5d61c "ml.dmlc.mxnet.NDArray@b16841b4"]}, +;; :aux-params {}} + +;;To assign parameter and aux state values, use `set-params` function. + +(m/set-params new-mod {:arg-params (m/arg-params new-mod) :aux-params (m/aux-params new-mod)}) + ;=>#object[ml.dmlc.mxnet.module.Module 0x11f34e1 "ml.dmlc.mxnet.module.Module@11f34e1"] + +;;To resume training from a saved checkpoint, instead of calling `set-params`, directly call `fit`, passing the loaded parameters, so that `fit` knows to start from those parameters instead of initializing randomly: + +;; reset the training data before calling fit or you will get an error +(mx-io/reset train-data) +(mx-io/reset test-data) + +(m/fit new-mod {:train-data train-data :eval-data test-data :num-epoch 2 + :fit-params (-> (m/fit-params {:begin-epoch 1}))}) + +;;Create fit-params, and then use it to set `begin-epoch` so that fit() knows to resume from a saved epoch. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/contrib/clojure-package/examples/tutorial/src/tutorial/ndarray.clj b/contrib/clojure-package/examples/tutorial/src/tutorial/ndarray.clj new file mode 100644 index 000000000000..858316eefdc4 --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/src/tutorial/ndarray.clj @@ -0,0 +1,92 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns tutorial.ndarray + (:require [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.context :as context])) + +;;The NDArray package (mxnet.ndarray) contains tensor operations similar to numpy.ndarray. The syntax is also similar, except for some additional calls for dealing with I/O and multiple devices. + +;;Create NDArray +;;Create mxnet.ndarray as follows: + +(def a (ndarray/zeros [100 50])) ;;all zero arrray of dimension 100 x 50 +(def b (ndarray/ones [256 32 128 1])) ;; all one array of dimension +(def c (ndarray/array [1 2 3 4 5 6] [2 3])) ;; array with contents of a shape 2 x 3 + +;;; There are also ways to convert to a vec or get the shape as an object or vec +(ndarray/->vec c) ;=> [1.0 2.0 3.0 4.0 5.0 6.0] +(ndarray/shape c) ;=> #object[org.apache.mxnet.Shape 0x583c865 "(2,3)"] +(ndarray/shape-vec c) ;=> [2 3] + + +;; NDArray Operations + +;; Arithmtic Operations +(def a (ndarray/ones [1 5])) +(def b (ndarray/ones [1 5])) +(-> (ndarray/+ a b) (ndarray/->vec)) ;=> [2.0 2.0 2.0 2.0 2.0] + +;; original ndarrays are unchanged +(ndarray/->vec a) ;=> [1.0 1.0 1.0 1.0 1.0] +(ndarray/->vec b) ;=> [1.0 1.0 1.0 1.0 1.0] + +;;inplace operators +(ndarray/+= a b) +(ndarray/->vec a) ;=> [2.0 2.0 2.0 2.0 2.0] + +;; other arthimetic operations are similar + +;; Slice operations + +(def a (ndarray/array [1 2 3 4 5 6] [3 2])) +(def a1 (ndarray/slice a 1)) +(ndarray/shape-vec a1) ;=> [1 2] +(ndarray/->vec a1) ;=> [3.0 4.0] + +(def a2 (ndarray/slice a 1 3)) +(ndarray/shape-vec a2) ;=>[2 2] +(ndarray/->vec a2) ;=> [3.0 4.0 5.0 6.0] + +;; Dot Product + +(def arr1 (ndarray/array [1 2] [1 2])) +(def arr2 (ndarray/array [3 4] [2 1])) +(def res (ndarray/dot arr1 arr2)) +(ndarray/shape-vec res) ;=> [1 1] +(ndarray/->vec res) ;=> [11.0] + +;;Save and Load NDArray +;;You can use MXNet functions to save and load a map of NDArrays from file systems, as follows: + +(ndarray/save "filename" {"arr1" arr1 "arr2" arr2}) +;; you can also do "s3://path" or "hdfs" + +;; to load +(def from-file (ndarray/load "filename")) +from-file ;=>{"arr1" #object[org.apache.mxnet.NDArray 0x6115ba61 "org.apache.mxnet.NDArray@43d85753"], "arr2" #object[org.apache.mxnet.NDArray 0x374b5eff "org.apache.mxnet.NDArray@5c93def4"]} + +;;Multi-Device Support + +;;Device information is stored in the mxnet.Context structure. When creating NDArray in MXNet, you can use the context argument (the default is the CPU context) to create arrays on specific devices as follows: + +(def cpu-a (ndarray/zeros [100 200])) +(ndarray/context cpu-a) ;=> #object[org.apache.mxnet.Context 0x3f376123 "cpu(0)"] + +(def gpu-b (ndarray/zeros [100 200] {:ctx (context/gpu 0)})) ;; to use with gpu + +;;Currently, we do not allow operations among arrays from different contexts. To manually enable this, use the copyto function to copy the content to different devices, and continue computation: diff --git a/contrib/clojure-package/examples/tutorial/src/tutorial/symbol.clj b/contrib/clojure-package/examples/tutorial/src/tutorial/symbol.clj new file mode 100644 index 000000000000..cfb680bffffc --- /dev/null +++ b/contrib/clojure-package/examples/tutorial/src/tutorial/symbol.clj @@ -0,0 +1,141 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns tutorial.symbol + (:require [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.context :as context])) + +;; How to compose symbols +;;The symbolic API provides a way to configure computation graphs. You can configure the graphs either at the level of neural network layer operations or as fine-grained operations. + +;;The following example configures a two-layer neural network. + +(def data (sym/variable "data")) +(def fc1 (sym/fully-connected "fc1" {:data data :num-hidden 128})) +(def act1 (sym/activation "act1" {:data fc1 :act-type "relu"})) +(def fc2 (sym/fully-connected "fc2" {:data act1 :num-hidden 64})) +(def net (sym/softmax-output "out" {:data fc2})) + +;; you could also combine this more dynamically with +(as-> (sym/variable "data") data + (sym/fully-connected "fc1" {:data data :num-hidden 128}) + (sym/activation "act1" {:data data :act-type "relu"}) + (sym/fully-connected "fc2" {:data data :num-hidden 64}) + (sym/softmax-output "out" {:data data})) + +net ;=> #object[ml.dmlc.mxnet.Symbol 0x38c72806 "ml.dmlc.mxnet.Symbol@38c72806"] + +;; Each symbol takes a (unique) string name. NDArray and Symbol both represent a single tensor. Operators represent the computation between tensors. Operators take symbol (or NDArray) as inputs and might also additionally accept other hyperparameters such as the number of hidden neurons (num_hidden) or the activation type (act_type) and produce the output. + +;; We can view a symbol simply as a function taking several arguments. And we can retrieve those arguments with the following method call: + +;;We can view a symbol simply as a function taking several arguments. And we can retrieve those arguments with the following method call: + +(sym/list-arguments net) + ;=> ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias" "out_label"] + +;; These arguments are the parameters and inputs needed by each symbol: + +;; data: Input data needed by the variable data. +;; fc1_weight and fc1_bias: The weight and bias for the first fully connected layer fc1. +;; fc2_weight and fc2_bias: The weight and bias for the second fully connected layer fc2. +;; out_label: The label needed by the loss. + +;;We can also specify the names explicitly: +(def net (sym/variable "data")) +(def w (sym/variable "myweight")) +(def net (sym/fully-connected "fc1" {:data net :weight w :num-hidden 128})) + +(sym/list-arguments net) +;=> ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias" "out_label" "myweight" "fc1_bias"] + +;;In the above example, FullyConnected layer has 3 inputs: data, weight, bias. When any input is not specified, a variable will be automatically generated for it. + + +;; More complicated composition + +;;MXNet provides well-optimized symbols for layers commonly used in deep learning (see src/operator). We can also define new operators in Python. The following example first performs an element-wise add between two symbols, then feeds them to the fully connected operator: + +(def lhs (sym/variable "data1")) +(def rhs (sym/variable "data2")) +(def net (sym/fully-connected "fc1" {:data (sym/+ lhs rhs) :num-hidden 128})) +(sym/list-arguments net) ;=> ["data1" "data2" "fc1_weight" "fc1_bias"] + +;; Group Multiple Symbols +;;To construct neural networks with multiple loss layers, we can use mxnet.sym.Group to group multiple symbols together. The following example groups two outputs: + +(def net (sym/variable "data")) +(def fc1 (sym/fully-connected {:data net :num-hidden 128})) +(def net2 (sym/activation {:data fc1 :act-type "relu"})) +(def out1 (sym/softmax-output {:data net2})) +(def out2 (sym/linear-regression-output {:data net2})) +(def group (sym/group [out1 out2])) +(sym/list-outputs group);=> ["softmaxoutput0_output" "linearregressionoutput0_output"] + + +;; Symbol Manipulation +;; One important difference of Symbol compared to NDArray is that we first declare the computation and then bind the computation with data to run. + +;; In this section, we introduce the functions to manipulate a symbol directly. But note that, most of them are wrapped by the module package. + +;; Shape and Type Inference +;; For each symbol, we can query its arguments, auxiliary states and outputs. We can also infer the output shape and type of the symbol given the known input shape or type of some arguments, which facilitates memory allocation. +(sym/list-arguments fc1) ;=> ["data" "fullyconnected1_weight" "fullyconnected1_bias"] +(sym/list-outputs fc1) ;=> ["fullyconnected1_output"] + +;; infer the shapes given the shape of the input arguments +(let [[arg-shapes out-shapes] (sym/infer-shape fc1 {:data [2 1]})] + {:arg-shapes arg-shapes + :out-shapes out-shapes}) ;=> {:arg-shapes ([2 1] [128 1] [128]), :out-shapes ([2 128])} + +;; Bind with Data and Evaluate +;; The symbol c constructed above declares what computation should be run. To evaluate it, we first need to feed the arguments, namely free variables, with data. + +;; We can do it by using the bind method, which accepts device context and a dict mapping free variable names to NDArrays as arguments and returns an executor. The executor provides forward method for evaluation and an attribute outputs to get all the results. + +(def a (sym/variable "a")) +(def b (sym/variable "b")) +(def c (sym/+ a b)) + +(def ex (sym/bind c {"a" (ndarray/ones [2 2]) "b" (ndarray/ones [2 2])})) +(-> (executor/forward ex) + (executor/outputs) + (first) + (ndarray/->vec));=> [2.0 2.0 2.0 2.0] + +;;We can evaluate the same symbol on GPU with different data. +;; To do this you must have the correct native library jar defined as a dependency + +;;Note In order to execute the following section on a cpu set gpu_device to (cpu). + + +(def ex (sym/bind c (context/gpu 0) {"a" (ndarray/ones [2 2]) "b" (ndarray/ones [2 2])})) + +;; Serialization +;; There are two ways to save and load the symbols. You can use the mxnet.Symbol.save and mxnet.Symbol.load functions to serialize the Symbol objects. The advantage of using save and load functions is that it is language agnostic and cloud friendly. The symbol is saved in JSON format. You can also get a JSON string directly using mxnet.Symbol.toJson. Refer to API documentation for more details. + +;; The following example shows how to save a symbol to a file, load it back, and compare two symbols using a JSON string. You can also save to S3 as well + +(def a (sym/variable "a")) +(def b (sym/variable "b")) +(def c (sym/+ a b)) +(sym/save c "symbol-c.json") +(def c2 (sym/load "symbol-c.json")) +(= (sym/to-json c) (sym/to-json c2)) ;=>true + diff --git a/contrib/clojure-package/examples/visualization/.gitignore b/contrib/clojure-package/examples/visualization/.gitignore new file mode 100644 index 000000000000..c53038ec0e3d --- /dev/null +++ b/contrib/clojure-package/examples/visualization/.gitignore @@ -0,0 +1,11 @@ +/target +/classes +/checkouts +pom.xml +pom.xml.asc +*.jar +*.class +/.lein-* +/.nrepl-port +.hgignore +.hg/ diff --git a/contrib/clojure-package/examples/visualization/README.md b/contrib/clojure-package/examples/visualization/README.md new file mode 100644 index 000000000000..8c6e2c2b3b79 --- /dev/null +++ b/contrib/clojure-package/examples/visualization/README.md @@ -0,0 +1,4 @@ +# visualization + +Run `lein run` to have a sample network visualization printed for you +"testviz.pdf" diff --git a/contrib/clojure-package/examples/visualization/project.clj b/contrib/clojure-package/examples/visualization/project.clj new file mode 100644 index 000000000000..6dc7c49f92ad --- /dev/null +++ b/contrib/clojure-package/examples/visualization/project.clj @@ -0,0 +1,22 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject visualization "0.1.0-SNAPSHOT" + :description "Visualization example" + :main visualization.core + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT"]]) diff --git a/contrib/clojure-package/examples/visualization/src/visualization/core.clj b/contrib/clojure-package/examples/visualization/src/visualization/core.clj new file mode 100644 index 000000000000..58980a0ca3ee --- /dev/null +++ b/contrib/clojure-package/examples/visualization/src/visualization/core.clj @@ -0,0 +1,50 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns visualization.core + (:require [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.visualization :as viz])) + +(defn get-symbol [] + (as-> (sym/variable "data") data + + #_(sym/convolution "conv1" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + #_(sym/batch-norm "bn1" {:data data}) + #_(sym/activation "relu1" {:data data :act-type "relu"}) + #_(sym/pooling "mp1" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + + + #_(sym/convolution "conv2" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + #_(sym/batch-norm "bn2" {:data data}) + #_(sym/activation "relu2" {:data data :act-type "relu"}) + #_(sym/pooling "mp2" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + + (sym/flatten "fl" {:data data}) + #_(sym/fully-connected "fc2" {:data data :num-hidden 10}) + (sym/softmax-output "softmax" {:data data}))) + +(defn test-viz [] + (let [dot (viz/plot-network (get-symbol) + {"data" [1 1 28 28]} + {:title "foo" :node-attrs {:shape "oval" :fixedsize "false"}})] + (viz/render dot "testviz" "./"))) + +(defn -main [& args] + (do (test-viz) + (println "Check for the testviz.pdf file in the project directory"))) + + diff --git a/contrib/clojure-package/examples/visualization/testviz b/contrib/clojure-package/examples/visualization/testviz new file mode 100644 index 000000000000..c0161e98f522 --- /dev/null +++ b/contrib/clojure-package/examples/visualization/testviz @@ -0,0 +1,32 @@ +digraph foo{ + data [label=data fixedsize=false style=filled height=0.8034 fillcolor="#8dd3c7" shape=oval width=1.3] + conv1 [label="Convolution\n3x3/2x2, 32" fixedsize=false style=filled height=0.8034 fillcolor="#fb8072" shape=oval width=1.3] + bn1 [label=bn1 fixedsize=false style=filled height=0.8034 fillcolor="#bebada" shape=oval width=1.3] + relu1 [label="Activation +relu" fixedsize=false style=filled height=0.8034 fillcolor="#ffffb3" shape=oval width=1.3] + mp1 [label="Pooling +max, 2x2/2x2" fixedsize=false style=filled height=0.8034 fillcolor="#80b1d3" shape=oval width=1.3] + conv2 [label="Convolution\n3x3/2x2, 32" fixedsize=false style=filled height=0.8034 fillcolor="#fb8072" shape=oval width=1.3] + bn2 [label=bn2 fixedsize=false style=filled height=0.8034 fillcolor="#bebada" shape=oval width=1.3] + relu2 [label="Activation +relu" fixedsize=false style=filled height=0.8034 fillcolor="#ffffb3" shape=oval width=1.3] + mp2 [label="Pooling +max, 2x2/2x2" fixedsize=false style=filled height=0.8034 fillcolor="#80b1d3" shape=oval width=1.3] + fl [label=fl fixedsize=false style=filled height=0.8034 fillcolor="#fdb462" shape=oval width=1.3] + fc2 [label="FullyConnected +10" fixedsize=false style=filled height=0.8034 fillcolor="#fb8072" shape=oval width=1.3] + softmax_label [label=softmax_label fixedsize=false style=filled height=0.8034 fillcolor="#8dd3c7" shape=oval width=1.3] + softmax [label=softmax fixedsize=false style=filled height=0.8034 fillcolor="#fccde5" shape=oval width=1.3] + conv1 -> data [ arrowtail=open dir=back label="1x28x28"] + bn1 -> conv1 [ arrowtail=open dir=back label="32x13x13"] + relu1 -> bn1 [ arrowtail=open dir=back label="32x13x13"] + mp1 -> relu1 [ arrowtail=open dir=back label="32x13x13"] + conv2 -> mp1 [ arrowtail=open dir=back label="32x6x6"] + bn2 -> conv2 [ arrowtail=open dir=back label="32x2x2"] + relu2 -> bn2 [ arrowtail=open dir=back label="32x2x2"] + mp2 -> relu2 [ arrowtail=open dir=back label="32x2x2"] + fl -> mp2 [ arrowtail=open dir=back label="32x1x1"] + fc2 -> fl [ arrowtail=open dir=back label="32"] + softmax -> fc2 [ arrowtail=open dir=back label="10"] + softmax -> softmax_label [ arrowtail=open dir=back label=""] +} diff --git a/contrib/clojure-package/examples/visualization/testviz.pdf b/contrib/clojure-package/examples/visualization/testviz.pdf new file mode 100644 index 000000000000..0acecb89119f Binary files /dev/null and b/contrib/clojure-package/examples/visualization/testviz.pdf differ diff --git a/contrib/clojure-package/project.clj b/contrib/clojure-package/project.clj new file mode 100644 index 000000000000..d7f5af4dbcc1 --- /dev/null +++ b/contrib/clojure-package/project.clj @@ -0,0 +1,54 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject org.apache.mxnet.contrib.clojure/clojure-mxnet "1.3.0-SNAPSHOT" + :description "Clojure package for MXNet" + :dependencies [[org.clojure/clojure "1.9.0"] + [t6/from-scala "0.3.0"] + + ;; Jars from Nexus + ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"] + ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"] + ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"] + + ;;; CI + [org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"] + + [org.clojure/tools.logging "0.4.0"] + [org.apache.logging.log4j/log4j-core "2.8.1"] + [org.apache.logging.log4j/log4j-api "2.8.1"] + [org.slf4j/slf4j-log4j12 "1.7.25" :exclusions [org.slf4j/slf4j-api]]] + :pedantic? :skip + :plugins [[lein-codox "0.10.3" :exclusions [org.clojure/clojure]] + [lein-cloverage "1.0.10" :exclusions [org.clojure/clojure]]] + :codox {:namespaces [#"^org\.apache\.clojure-mxnet\.(?!gen).*"]} + :aliases {"generate-code" ["run" "-m" "dev.generator"]} + :repositories [["staging" {:url "https://repository.apache.org/content/repositories/staging" + ;; If a repository contains releases only setting + ;; :snapshots to false will speed up dependencies. + :snapshots true + ;; Disable signing releases deployed to this repo. + ;; (Not recommended.) + :sign-releases false + ;; You can also set the policies for how to handle + ;; :checksum failures to :fail, :warn, or :ignore. + :checksum :fail + ;; How often should this repository be checked for + ;; snapshot updates? (:daily, :always, or :never) + :update :always + ;; You can also apply them to releases only: + :releases {:checksum :fail :update :always}}]]) diff --git a/contrib/clojure-package/resources/log4j.properties b/contrib/clojure-package/resources/log4j.properties new file mode 100644 index 000000000000..e012700dce2a --- /dev/null +++ b/contrib/clojure-package/resources/log4j.properties @@ -0,0 +1,5 @@ +log4j.rootLogger=INFO, console +log4j.logger.example=DEBUG +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%-5p %c: %m%n diff --git a/contrib/clojure-package/scripts/get_cifar_data.sh b/contrib/clojure-package/scripts/get_cifar_data.sh new file mode 100755 index 000000000000..372c7bb5781e --- /dev/null +++ b/contrib/clojure-package/scripts/get_cifar_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +set -evx + +if [ ! -z "$MXNET_DATA_DIR" ]; then + data_path="$MXNET_DATA_DIR" +else + data_path="./data" +fi + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +cifar_data_path="$data_path/cifar10.zip" +if [ ! -f "$cifar_data_path" ]; then + wget http://data.mxnet.io/mxnet/data/cifar10.zip -P $data_path + cd $data_path + unzip -u cifar10.zip +fi diff --git a/contrib/clojure-package/scripts/get_mnist_data.sh b/contrib/clojure-package/scripts/get_mnist_data.sh new file mode 100755 index 000000000000..6f32b85f480b --- /dev/null +++ b/contrib/clojure-package/scripts/get_mnist_data.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +set -evx + +if [ ! -z "$MXNET_DATA_DIR" ]; then + data_path="$MXNET_DATA_DIR" +else + data_path="./data" +fi + +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +mnist_data_path="$data_path/mnist.zip" +if [ ! -f "$mnist_data_path" ]; then + wget http://data.mxnet.io/mxnet/data/mnist.zip -P $data_path + cd $data_path + unzip -u mnist.zip +fi diff --git a/contrib/clojure-package/src/dev/generator.clj b/contrib/clojure-package/src/dev/generator.clj new file mode 100644 index 000000000000..35e7a250696a --- /dev/null +++ b/contrib/clojure-package/src/dev/generator.clj @@ -0,0 +1,329 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns dev.generator + (:require [t6.from-scala.core :as scala] + [clojure.reflect :as r] + [org.apache.clojure-mxnet.util :as util] + [clojure.pprint]) + (:import (org.apache.mxnet NDArray Symbol)) + (:gen-class)) + + +(defn clojure-case + [string] + (-> string + (clojure.string/replace #"(\s+)([A-Z][a-z]+)" "$1-$2") + (clojure.string/replace #"([A-Z]+)([A-Z][a-z]+)" "$1-$2") + (clojure.string/replace #"([a-z0-9])([A-Z])" "$1-$2") + (clojure.string/lower-case) + (clojure.string/replace #"\_" "-") + (clojure.string/replace #"\/" "div"))) + +(defn symbol-transform-param-name [parameter-types] + (->> parameter-types + (map str) + (map (fn [x] (or (util/symbol-param-coerce x) x))) + (map (fn [x] (last (clojure.string/split x #"\.")))))) + +(defn ndarray-transform-param-name [parameter-types] + (->> parameter-types + (map str) + (map (fn [x] (or (util/ndarray-param-coerce x) x))) + (map (fn [x] (last (clojure.string/split x #"\.")))))) + +(defn has-variadic? [params] + (->> params + (map str) + (filter (fn [s] (re-find #"\&" s))) + count + pos?)) + + +(defn increment-param-name [pname] + (if-let [num-str (re-find #"-\d" pname)] + (str (first (clojure.string/split pname #"-")) "-" (inc (Integer/parseInt (last (clojure.string/split num-str #"-"))))) + (str pname "-" 1))) + +(defn rename-duplicate-params [params] + (reduce (fn [known-names n] (conj known-names (if (contains? (set known-names) n) + (increment-param-name n) + n))) + [] + params)) + + +;;;;;;; symbol + +(def symbol-reflect-info (->> (:members (r/reflect Symbol)) + (map #(into {} %)))) + +(def symbol-public (filter (fn [x] (-> x :flags :public)) symbol-reflect-info)) + +(def symbol-public-no-default (->> symbol-public + (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %))))) + (filter #(not (re-find #"\$default" (str (:name %))))))) + +(into #{} (mapcat :parameter-types symbol-public-no-default)) + ;#{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>} + +(def symbol-hand-gen-set #{"scala.Option" + "int org.apache.mxnet.Executor" + "scala.Enumeration$Value" + "org.apache.mxnet.Context" + "scala.Tuple2" + "scala.collection.Traversable"} ) + +;;; min and max have a conflicting arity of 2 with the auto gen signatures +(def symbol-filter-name-set #{"max" "min"}) + +(defn is-symbol-hand-gen? [info] + (or + (->> (:name info) + str + (get symbol-filter-name-set)) + (->> (map str (:parameter-types info)) + (into #{}) + (clojure.set/intersection symbol-hand-gen-set) + count + pos?))) + +(def symbol-public-to-hand-gen (filter is-symbol-hand-gen? symbol-public-no-default)) +(def symbol-public-to-gen (->> (remove #(contains?(->> symbol-public-to-hand-gen + (mapv :name) + (mapv str) + (set)) (str (:name %))) symbol-public-no-default))) + + +(count symbol-public-to-hand-gen) ;=> 35 mostly bind! +(count symbol-public-to-gen) ;=> 307 + +(into #{} (map :name symbol-public-to-hand-gen));=> #{arange bind ones zeros simpleBind Variable} + +(defn public-by-name-and-param-count [public-reflect-info] + (->> public-reflect-info + (group-by :name) + (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)])) + (into {}))) + + +(defn symbol-vector-args [] + `(if (map? ~'kwargs-map-or-vec-or-sym) (~'util/empty-list) (~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"}))) + +(defn symbol-map-args [] + `(if (map? ~'kwargs-map-or-vec-or-sym) (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym) nil)) + + +(defn add-symbol-arities [params function-name] + (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] (mapv str params)) + [`([~'sym-name ~'attr-map ~'kwargs-map] + (~function-name ~'sym-name (~'util/convert-symbol-map ~'attr-map) (~'util/empty-list) (~'util/convert-symbol-map ~'kwargs-map))) + `([~'sym-name ~'kwargs-map-or-vec-or-sym] + (~function-name ~'sym-name nil ~(symbol-vector-args) ~(symbol-map-args))) + `([~'kwargs-map-or-vec-or-sym] + (~function-name nil nil ~(symbol-vector-args) ~(symbol-map-args)))])) + +(defn gen-symbol-function-arity [op-name op-values function-name] + (mapcat + (fn [[param-count info]] + (let [targets (->> (mapv :parameter-types info) + (apply interleave) + (mapv str) + (partition (count info)) + (mapv set)) + pnames (->> (mapv :parameter-types info) + (mapv symbol-transform-param-name) + (apply interleave) + (partition (count info)) + (mapv #(clojure.string/join "-or-" %)) + (rename-duplicate-params) + (mapv symbol)) + coerced-params (mapv (fn [p t] `(~'util/nil-or-coerce-param ~(symbol (clojure.string/replace p #"\& " "")) ~t)) pnames targets) + params (if (= #{:public :static} (:flags (first info))) + pnames + (into ['sym] pnames)) + function-body (if (= #{:public :static} (:flags (first info))) + `(~'util/coerce-return (~(symbol (str "Symbol/" op-name)) ~@coerced-params)) + `(~'util/coerce-return (~(symbol (str "." op-name)) ~'sym ~@coerced-params) + ))] + (when (not (and (> param-count 1) (has-variadic? params))) + `[( + ~params + ~function-body + ) + ~@(add-symbol-arities params function-name)]))) + op-values)) + + +(def all-symbol-functions + (for [operation (sort (public-by-name-and-param-count symbol-public-to-gen))] + (let [[op-name op-values] operation + function-name (-> op-name + str + scala/decode-scala-symbol + clojure-case + symbol)] + `(~'defn ~function-name + ~@(remove nil? (gen-symbol-function-arity op-name op-values function-name)))))) + +(def license + (str + ";; Licensed to the Apache Software Foundation (ASF) under one or more\n" + ";; contributor license agreements. See the NOTICE file distributed with\n" + ";; this work for additional information regarding copyright ownership.\n" + ";; The ASF licenses this file to You under the Apache License, Version 2.0\n" + ";; (the \"License\"); you may not use this file except in compliance with\n" + ";; the License. You may obtain a copy of the License at\n" + ";;\n" + ";; http://www.apache.org/licenses/LICENSE-2.0\n" + ";;\n" + ";; Unless required by applicable law or agreed to in writing, software\n" + ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n" + ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + ";; See the License for the specific language governing permissions and\n" + ";; limitations under the License.\n" + ";;\n")) + +(defn write-to-file [functions ns-gen fname] + (with-open [w (clojure.java.io/writer fname)] + (.write w ns-gen) + (.write w "\n\n") + (.write w ";; Do not edit - this is auto-generated") + (.write w "\n\n") + (.write w license) + (.write w "\n\n") + (.write w "\n\n") + (doseq [f functions] + (clojure.pprint/pprint f w) + (.write w "\n")))) + +(def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle]) + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Symbol)))") + + +(defn generate-symbol-file [] + (write-to-file all-symbol-functions symbol-gen-ns "src/org/apache/clojure_mxnet/gen/symbol.clj")) + + +;;;;;;;;NDARRAY + + +(def ndarray-reflect-info (->> (:members (r/reflect NDArray)) + (map #(into {} %)))) + + +(def ndarray-public (filter (fn [x] (-> x :flags :public)) ndarray-reflect-info)) + +(def ndarray-public-no-default (->> ndarray-public + (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %))))) + (filter #(not (re-find #"\$default" (str (:name %))))))) + +(def ndarray-hand-gen-set #{"org.apache.mxnet.NDArrayFuncReturn" + "org.apache.mxnet.Context" + "scala.Enumeration$Value" + "scala.Tuple2" + "scala.collection.Traversable"} ) + +(defn is-ndarray-hand-gen? [info] + (->> (map str (:parameter-types info)) + (into #{}) + (clojure.set/intersection ndarray-hand-gen-set) + count + pos?)) + + +(def ndarray-public-to-hand-gen (filter is-ndarray-hand-gen? ndarray-public-no-default)) +(def ndarray-public-to-gen (->> (remove #(contains?(->> ndarray-public-to-hand-gen + (mapv :name) + (mapv str) + (set)) (str (:name %))) ndarray-public-no-default))) + + +(count ndarray-public-to-hand-gen) ;=> 15 +(count ndarray-public-to-gen) ;=> 486 + +(map :name ndarray-public-to-hand-gen) + + + +(defn gen-ndarray-function-arity [op-name op-values] + (for [[param-count info] op-values] + (let [targets (->> (mapv :parameter-types info) + (apply interleave) + (mapv str) + (partition (count info)) + (mapv set)) + pnames (->> (mapv :parameter-types info) + (mapv ndarray-transform-param-name) + (apply interleave) + (partition (count info)) + (mapv #(clojure.string/join "-or-" %)) + (rename-duplicate-params) + (mapv symbol)) + coerced-params (mapv (fn [p t] `(~'util/coerce-param ~(symbol (clojure.string/replace p #"\& " "")) ~t)) pnames targets) + params (if (= #{:public :static} (:flags (first info))) + pnames + (into ['ndarray] pnames)) + function-body (if (= #{:public :static} (:flags (first info))) + `(~'util/coerce-return (~(symbol (str "NDArray/" op-name)) ~@coerced-params)) + `(~'util/coerce-return (~(symbol (str "." op-name)) ~'ndarray ~@coerced-params) + ))] + (when (not (and (> param-count 1) (has-variadic? params))) + `( + ~params + ~function-body + ))))) + + +(def all-ndarray-functions + (for [operation (sort (public-by-name-and-param-count ndarray-public-to-gen))] + (let [[op-name op-values] operation + function-name (-> op-name + str + scala/decode-scala-symbol + clojure-case + symbol)] + `(~'defn ~function-name + ~@(remove nil? (gen-ndarray-function-arity op-name op-values)))))) + +(def ndarray-gen-ns "(ns org.apache.clojure-mxnet.ndarray + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle]) + (:import (org.apache.mxnet NDArray Shape)))") + + +(defn generate-ndarray-file [] + (write-to-file all-ndarray-functions ndarray-gen-ns "src/org/apache/clojure_mxnet/gen/ndarray.clj")) + +(defn -main [& args] + (do + (println "Generating the core ndarray api from the Scala classes") + (generate-ndarray-file) + (println "Generating the core symbol api from the Scala classes") + (generate-symbol-file))) + +(comment + + ;; This generates a file with the bulk of the nd-array functions + (generate-ndarray-file) + + ;; This generates a file with the bulk of the symbol functions + (generate-symbol-file) ) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/base.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/base.clj new file mode 100644 index 000000000000..41ef821cd63b --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/base.clj @@ -0,0 +1,21 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.base + (:import (org.apache.mxnet Base))) + +(def MX_REAL_TYPE (Base/MX_REAL_TYPE)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj new file mode 100644 index 000000000000..d1c6d8820f9c --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/callback.clj @@ -0,0 +1,32 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.callback + (:import (org.apache.mxnet Callback$Speedometer))) + + +;;; used to track status during epoch + +(defn speedometer + ([batch-size frequent] + (new Callback$Speedometer (int batch-size) (int frequent))) + ([batch-size] + (speedometer batch-size 50))) + +(defn invoke [callback epoch nbatch metric] + (doto callback + (.invoke (int epoch) (int nbatch) metric))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/context.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/context.clj new file mode 100644 index 000000000000..f89fd58e43cf --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/context.clj @@ -0,0 +1,43 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.context + (:import (org.apache.mxnet Context))) + + +(defn cpu + ([device-id] + (new Context "cpu" device-id)) + ([] + (cpu 0))) + +(defn gpu + ([device-id] + (new Context "gpu" device-id)) + ([] + (gpu 0))) + +(defn cpu-context [] + (cpu)) + +(defn default-context [] (cpu-context)) + +(defn device-type [context] + (.deviceType context)) + +(defn device-id [context] + (.deviceId context)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/dtype.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/dtype.clj new file mode 100644 index 000000000000..d21fe7f7ce8a --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/dtype.clj @@ -0,0 +1,27 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.dtype + (:import (org.apache.mxnet DType))) + + +(def UINT8 (DType/UInt8)) +(def INT32 (DType/Int32)) +(def FLOAT16 (DType/Float16)) +(def FLOAT32 (DType/Float32)) +(def FLOAT64 (DType/Float64)) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj new file mode 100644 index 000000000000..3cddb1fabc40 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj @@ -0,0 +1,101 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.eval-metric + (:refer-clojure :exclude [get update]) + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Accuracy TopKAccuracy F1 Perplexity MAE MSE RMSE CustomMetric))) + + +(defn accuracy + "Basic Accuracy Metric" + [] + (new Accuracy)) + +(defn top-k-accuracy + "Calculate to k predications accuracy + - top-k number of predicts (int)" + [top-k] + (new TopKAccuracy (int top-k))) + +(defn f1 + "Calculate the F1 score of a binary classification problem." + [] + (new F1)) + +(defn perplexity + "Calculate perplexity + - opts + :ignore-label Index of invalid label to ignore when counting. Usually should be -1. Include + all entries if None. + :axis The axis from prediction that was used to + compute softmax. Default is -1 which means use the last axis." + ([{:keys [ignore-label axis] :as opts + :or {axis -1}}] + (new Perplexity + (if ignore-label (util/->option (int ignore-label)) (util/->option nil)) + (int axis))) + ([] + (perplexity {}))) + +(defn mae + "Calculate Mean Absolute Error loss" + [] + (new MAE)) + +(defn mse + "Calculate Mean Squared Error loss" + [] + (new MSE)) + +(defn rmse + "Calculate Root Mean Squred Error loss" + [] + (new RMSE)) + +(defmacro custom-metric + "Custom evaluation metric that takes a NDArray function. + - f-eval Customized evaluation function that takes two ndarrays and returns a number + function must be in the form of (fn [] ) clojure style + - mname The name of the metric" + [f-eval mname] + `(new CustomMetric (util/scala-fn ~f-eval) ~mname)) + +(defn get + "Get the values of the metric in a vector form (name and value)" + [metric] + (let [[[mname] [mvalue]] (util/tuple->vec (.get metric))] + [mname mvalue])) + +(defn reset + "clear the internal statistics to an initial state" + [metric] + (doto metric + (.reset))) + +(defn update + "Update the internal evaluation" + [metric labels preds] + (doto metric + (.update (util/vec->indexed-seq labels) (util/vec->indexed-seq preds)))) + +(defn get-and-reset + "Get the values and then reset the metric" + [metric] + (let [v (get metric)] + (reset metric) + v)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj new file mode 100644 index 000000000000..d1f8df90c96e --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj @@ -0,0 +1,102 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.executor + (:require [org.apache.clojure-mxnet.util :as util] + [clojure.reflect :as r] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.shape :as mx-shape])) + + +;; need to revisit to get all functions + +(defn ->vec [nd-array] + (vec (.toArray nd-array))) + +(defn forward + "* Calculate the outputs specified by the binded symbol. + * @param is-train whether this forward is for evaluation purpose. + * @param kwargs Additional specification of input arguments." + ([executor] + (do (.forward executor) + executor)) + ([executor is-train kwargs] + (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"}))))) + +(defn backward + "* Do backward pass to get the gradient of arguments. + * @param ndarray-or-vec Gradient on the outputs to be propagated back. + * This parameter is only needed when bind is called + * on outputs that are not a loss function." + ([executor] + (do (.backward executor) + executor)) + ([executor ndarray-or-vec] + (do (.backward executor (if (vector? ndarray-or-vec) (into-array ndarray-or-vec) ndarray-or-vec)) + executor))) + +(defn outputs [executor] + "list all the output ndarrays" + (.outputs executor)) + +(defn grad-arrays [executor] + "list all the gradient ndarrays" + (.gradArrays executor)) + +(defn arg-arrays [executor] + "list all the argument ndarrays" + (.argArrays executor)) + +(defn grad-map [executor] + (util/scala-map->map (.gradDict executor))) + +(defn arg-map [executor] + (util/scala-map->map (.argDict executor))) + +(defn set-arg [executor arg-name arg-val-or-vec] + (-> executor + (arg-map) + (get arg-name) + (ndarray/set arg-val-or-vec))) + +(defn set-arg-arrays [executor vec-of-ndarray-or-val] + (doall (map (fn [arg-array v] (ndarray/set arg-array v)) (vec (arg-arrays executor)) vec-of-ndarray-or-val))) + +(defn get-grad [executor grad-name] + (-> executor + (grad-map) + (get grad-name))) + +(defn reshape + " * Return a new executor with the same symbol and shared memory, + * but different input/output shapes. + * For runtime reshaping, variable length sequences, etc. + * The returned executor shares state with the current one, + * and cannot be used in parallel with it. + * @param kwargs Map of string to shape-vec. + * - new shape for arguments. + * @parms opts with :partial-shaping Whether to allow changing the shape of unspecified arguments. + * and :allow-up-sizing Whether to allow allocating new ndarrays that's larger than the original." + ([executor kwargs {:keys [partial-shaping allow-up-sizing] + :or {partial-shaping false allow-up-sizing false}}] + (do + (let [kwargs-shapes (zipmap (keys kwargs) + (mapv (fn [v] (if (vector? v) (mx-shape/->shape v) v)) (vals kwargs)))] + (.reshape executor partial-shaping allow-up-sizing (util/convert-map kwargs-shapes))) + executor)) + ([executor kwargs] + (reshape executor kwargs {}))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/ndarray.clj new file mode 100644 index 000000000000..774b70456d91 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/ndarray.clj @@ -0,0 +1,2312 @@ +(ns org.apache.clojure-mxnet.ndarray + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle]) + (:import (org.apache.mxnet NDArray Shape))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + div + ([ndarray num-or-ndarray] + (util/coerce-return + (.$div + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + div= + ([ndarray num-or-ndarray] + (util/coerce-return + (.$div$eq + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + > + ([ndarray ndarray-or-num] + (util/coerce-return + (.$greater + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + >= + ([ndarray ndarray-or-num] + (util/coerce-return + (.$greater$eq + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + < + ([ndarray ndarray-or-num] + (util/coerce-return + (.$less + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + <= + ([ndarray ndarray-or-num] + (util/coerce-return + (.$less$eq + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + - + ([ndarray ndarray-or-num] + (util/coerce-return + (.$minus + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + -= + ([ndarray ndarray-or-num] + (util/coerce-return + (.$minus$eq + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + % + ([ndarray num-or-ndarray] + (util/coerce-return + (.$percent + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + %= + ([ndarray num-or-ndarray] + (util/coerce-return + (.$percent$eq + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + + + ([ndarray ndarray-or-num] + (util/coerce-return + (.$plus + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + += + ([ndarray num-or-ndarray] + (util/coerce-return + (.$plus$eq + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + * + ([ndarray ndarray-or-num] + (util/coerce-return + (.$times + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + *= + ([ndarray ndarray-or-num] + (util/coerce-return + (.$times$eq + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + ** + ([ndarray num-or-ndarray] + (util/coerce-return + (.$times$times + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + **= + ([ndarray ndarray-or-num] + (util/coerce-return + (.$times$times$eq + ndarray + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + activation + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Activation + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + batch-norm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/BatchNorm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + batch-norm-v1 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/BatchNorm_v1 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + bilinear-sampler + ([& nd-array-and-params] + (util/coerce-return + (NDArray/BilinearSampler + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + block-grad + ([& nd-array-and-params] + (util/coerce-return + (NDArray/BlockGrad + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + cast + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Cast + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + concat + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Concat + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + convolution + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Convolution + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + convolution-v1 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Convolution_v1 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + correlation + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Correlation + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + crop + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Crop + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + custom + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Custom + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + deconvolution + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Deconvolution + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + dropout + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Dropout + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + element-wise-sum + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ElementWiseSum + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + embedding + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Embedding + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + flatten + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Flatten + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + fully-connected + ([& nd-array-and-params] + (util/coerce-return + (NDArray/FullyConnected + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + grid-generator + ([& nd-array-and-params] + (util/coerce-return + (NDArray/GridGenerator + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + identity-attach-kl-sparse-reg + ([& nd-array-and-params] + (util/coerce-return + (NDArray/IdentityAttachKLSparseReg + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + instance-norm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/InstanceNorm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + l2-normalization + ([& nd-array-and-params] + (util/coerce-return + (NDArray/L2Normalization + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + lrn + ([& nd-array-and-params] + (util/coerce-return + (NDArray/LRN + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + layer-norm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/LayerNorm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + leaky-re-lu + ([& nd-array-and-params] + (util/coerce-return + (NDArray/LeakyReLU + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linear-regression-output + ([& nd-array-and-params] + (util/coerce-return + (NDArray/LinearRegressionOutput + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + logistic-regression-output + ([& nd-array-and-params] + (util/coerce-return + (NDArray/LogisticRegressionOutput + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + mae-regression-output + ([& nd-array-and-params] + (util/coerce-return + (NDArray/MAERegressionOutput + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + make-loss + ([& nd-array-and-params] + (util/coerce-return + (NDArray/MakeLoss + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + pad + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Pad + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + pooling + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Pooling + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + pooling-v1 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Pooling_v1 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rnn + ([& nd-array-and-params] + (util/coerce-return + (NDArray/RNN + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + roi-pooling + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ROIPooling + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + reshape + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Reshape + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + svm-output + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SVMOutput + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sequence-last + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SequenceLast + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sequence-mask + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SequenceMask + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sequence-reverse + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SequenceReverse + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + slice-channel + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SliceChannel + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softmax + ([& nd-array-and-params] + (util/coerce-return + (NDArray/Softmax + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softmax-activation + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SoftmaxActivation + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softmax-output + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SoftmaxOutput + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + spatial-transformer + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SpatialTransformer + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + swap-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/SwapAxis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn t ([ndarray] (util/coerce-return (.T ndarray)))) + +(defn + up-sampling + ([& nd-array-and-params] + (util/coerce-return + (NDArray/UpSampling + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + abs + ([& nd-array-and-params] + (util/coerce-return + (NDArray/abs + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + adam-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/adam_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + add-n + ([& nd-array-and-params] + (util/coerce-return + (NDArray/add_n + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arccos + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arccos + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arccosh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arccosh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arcsin + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arcsin + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arcsinh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arcsinh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arctan + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arctan + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + arctanh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/arctanh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + argmax + ([& nd-array-and-params] + (util/coerce-return + (NDArray/argmax + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + argmax-channel + ([& nd-array-and-params] + (util/coerce-return + (NDArray/argmax_channel + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + argmin + ([& nd-array-and-params] + (util/coerce-return + (NDArray/argmin + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + argsort + ([& nd-array-and-params] + (util/coerce-return + (NDArray/argsort + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + at + ([ndarray num] + (util/coerce-return (.at ndarray (util/coerce-param num #{"int"}))))) + +(defn + batch-dot + ([& nd-array-and-params] + (util/coerce-return + (NDArray/batch_dot + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + batch-take + ([& nd-array-and-params] + (util/coerce-return + (NDArray/batch_take + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-add + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_add + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-axes + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_axes + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_axis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-div + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_div + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-equal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_equal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-greater + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_greater + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-greater-equal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_greater_equal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-hypot + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_hypot + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-lesser + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_lesser + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-lesser-equal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_lesser_equal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-maximum + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_maximum + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-minimum + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_minimum + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-minus + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_minus + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-mod + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_mod + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-mul + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_mul + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-not-equal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_not_equal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-plus + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_plus + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-power + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_power + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-sub + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_sub + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + broadcast-to + ([& nd-array-and-params] + (util/coerce-return + (NDArray/broadcast_to + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + cast + ([& nd-array-and-params] + (util/coerce-return + (NDArray/cast + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + cast-storage + ([& nd-array-and-params] + (util/coerce-return + (NDArray/cast_storage + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + cbrt + ([& nd-array-and-params] + (util/coerce-return + (NDArray/cbrt + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + ceil + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ceil + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + choose-element-0index + ([& nd-array-and-params] + (util/coerce-return + (NDArray/choose_element_0index + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + clip + ([& nd-array-and-params] + (util/coerce-return + (NDArray/clip + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + concat + ([& nd-array-and-params] + (util/coerce-return + (NDArray/concat + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + concatenate + ([& nd-array-and-params] + (util/coerce-return + (NDArray/concatenate + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn context ([ndarray] (util/coerce-return (.context ndarray)))) + +(defn copy ([ndarray] (util/coerce-return (.copy ndarray)))) + +(defn + cos + ([& nd-array-and-params] + (util/coerce-return + (NDArray/cos + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + cosh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/cosh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + creation-trace + ([ndarray] (util/coerce-return (.creationTrace ndarray)))) + +(defn + crop + ([& nd-array-and-params] + (util/coerce-return + (NDArray/crop + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + degrees + ([& nd-array-and-params] + (util/coerce-return + (NDArray/degrees + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + dependencies + ([ndarray] (util/coerce-return (.dependencies ndarray)))) + +(defn + deserialize + ([byte-array] + (util/coerce-return + (NDArray/deserialize (util/coerce-param byte-array #{"byte<>"}))))) + +(defn dispose ([ndarray] (util/coerce-return (.dispose ndarray)))) + +(defn + dispose-deps + ([ndarray] (util/coerce-return (.disposeDeps ndarray)))) + +(defn + dispose-deps-except + ([ndarray & nd-array-and-params] + (util/coerce-return + (.disposeDepsExcept + ndarray + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + dot + ([& nd-array-and-params] + (util/coerce-return + (NDArray/dot + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn dtype ([ndarray] (util/coerce-return (.dtype ndarray)))) + +(defn + elemwise-add + ([& nd-array-and-params] + (util/coerce-return + (NDArray/elemwise_add + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + elemwise-div + ([& nd-array-and-params] + (util/coerce-return + (NDArray/elemwise_div + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + elemwise-mul + ([& nd-array-and-params] + (util/coerce-return + (NDArray/elemwise_mul + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + elemwise-sub + ([& nd-array-and-params] + (util/coerce-return + (NDArray/elemwise_sub + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + equal + ([ndarray-or-ndarray ndarray-or-num] + (util/coerce-return + (NDArray/equal + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + equals + ([ndarray Object] + (util/coerce-return + (.equals ndarray (util/coerce-param Object #{"java.lang.Object"}))))) + +(defn + exp + ([& nd-array-and-params] + (util/coerce-return + (NDArray/exp + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + expand-dims + ([& nd-array-and-params] + (util/coerce-return + (NDArray/expand_dims + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + expm1 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/expm1 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + fill-element-0index + ([& nd-array-and-params] + (util/coerce-return + (NDArray/fill_element_0index + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn finalize ([ndarray] (util/coerce-return (.finalize ndarray)))) + +(defn + fix + ([& nd-array-and-params] + (util/coerce-return + (NDArray/fix + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + flatten + ([& nd-array-and-params] + (util/coerce-return + (NDArray/flatten + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + flip + ([& nd-array-and-params] + (util/coerce-return + (NDArray/flip + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + floor + ([& nd-array-and-params] + (util/coerce-return + (NDArray/floor + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + ftml-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ftml_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + ftrl-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ftrl_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + gamma + ([& nd-array-and-params] + (util/coerce-return + (NDArray/gamma + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + gammaln + ([& nd-array-and-params] + (util/coerce-return + (NDArray/gammaln + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + gather-nd + ([& nd-array-and-params] + (util/coerce-return + (NDArray/gather_nd + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + greater + ([ndarray-or-ndarray ndarray-or-num] + (util/coerce-return + (NDArray/greater + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + greater-equal + ([ndarray-or-ndarray num-or-ndarray] + (util/coerce-return + (NDArray/greaterEqual + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn handle ([ndarray] (util/coerce-return (.handle ndarray)))) + +(defn hash-code ([ndarray] (util/coerce-return (.hashCode ndarray)))) + +(defn + identity + ([& nd-array-and-params] + (util/coerce-return + (NDArray/identity + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn internal ([ndarray] (util/coerce-return (.internal ndarray)))) + +(defn + is-disposed + ([ndarray] (util/coerce-return (.isDisposed ndarray)))) + +(defn + khatri-rao + ([& nd-array-and-params] + (util/coerce-return + (NDArray/khatri_rao + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + lesser + ([ndarray-or-ndarray ndarray-or-num] + (util/coerce-return + (NDArray/lesser + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + lesser-equal + ([ndarray-or-ndarray ndarray-or-num] + (util/coerce-return + (NDArray/lesserEqual + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + linalg-gelqf + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_gelqf + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-gemm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_gemm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-gemm2 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_gemm2 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-potrf + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_potrf + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-potri + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_potri + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-sumlogdiag + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_sumlogdiag + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-syrk + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_syrk + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-trmm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_trmm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + linalg-trsm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/linalg_trsm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + load + ([String] + (util/coerce-return + (NDArray/load (util/coerce-param String #{"java.lang.String"}))))) + +(defn + load2-array + ([String] + (util/coerce-return + (NDArray/load2Array + (util/coerce-param String #{"java.lang.String"}))))) + +(defn + load2-map + ([String] + (util/coerce-return + (NDArray/load2Map + (util/coerce-param String #{"java.lang.String"}))))) + +(defn + log + ([& nd-array-and-params] + (util/coerce-return + (NDArray/log + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + log10 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/log10 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + log1p + ([& nd-array-and-params] + (util/coerce-return + (NDArray/log1p + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + log2 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/log2 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + log-dispose-warning + ([ndarray] (util/coerce-return (.logDisposeWarning ndarray)))) + +(defn + log-softmax + ([& nd-array-and-params] + (util/coerce-return + (NDArray/log_softmax + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + make-loss + ([& nd-array-and-params] + (util/coerce-return + (NDArray/make_loss + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + max + ([& nd-array-and-params] + (util/coerce-return + (NDArray/max + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + max-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/max_axis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + maximum + ([num-or-ndarray-or-ndarray ndarray-or-num-or-ndarray] + (util/coerce-return + (NDArray/maximum + (util/coerce-param + num-or-ndarray-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + mean + ([& nd-array-and-params] + (util/coerce-return + (NDArray/mean + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + min + ([& nd-array-and-params] + (util/coerce-return + (NDArray/min + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + min-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/min_axis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + minimum + ([ndarray-or-ndarray-or-num num-or-ndarray-or-ndarray] + (util/coerce-return + (NDArray/minimum + (util/coerce-param + ndarray-or-ndarray-or-num + #{"float" "org.apache.mxnet.NDArray"}) + (util/coerce-param + num-or-ndarray-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + mp-sgd-mom-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/mp_sgd_mom_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + mp-sgd-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/mp_sgd_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + nanprod + ([& nd-array-and-params] + (util/coerce-return + (NDArray/nanprod + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + nansum + ([& nd-array-and-params] + (util/coerce-return + (NDArray/nansum + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + negative + ([& nd-array-and-params] + (util/coerce-return + (NDArray/negative + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + norm + ([& nd-array-and-params] + (util/coerce-return + (NDArray/norm + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + normal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/normal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + not-equal + ([ndarray-or-ndarray num-or-ndarray] + (util/coerce-return + (NDArray/notEqual + (util/coerce-param + ndarray-or-ndarray + #{"org.apache.mxnet.NDArray"}) + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + one-hot + ([& nd-array-and-params] + (util/coerce-return + (NDArray/one_hot + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + onehot-encode + ([ndarray ndarray-1] + (util/coerce-return + (NDArray/onehotEncode + (util/coerce-param ndarray #{"org.apache.mxnet.NDArray"}) + (util/coerce-param ndarray-1 #{"org.apache.mxnet.NDArray"}))))) + +(defn + ones-like + ([& nd-array-and-params] + (util/coerce-return + (NDArray/ones_like + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + org.apache.mxnet.nd-array + ([ndarray long bool] + (util/coerce-return + (.org.apache.mxnet.NDArray + ndarray + (util/coerce-param long #{"long"}) + (util/coerce-param bool #{"boolean"}))))) + +(defn + pad + ([& nd-array-and-params] + (util/coerce-return + (NDArray/pad + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + pick + ([& nd-array-and-params] + (util/coerce-return + (NDArray/pick + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + power + ([num-or-ndarray-or-ndarray ndarray-or-num-or-ndarray] + (util/coerce-return + (NDArray/power + (util/coerce-param + num-or-ndarray-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}) + (util/coerce-param + ndarray-or-num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + +(defn + prod + ([& nd-array-and-params] + (util/coerce-return + (NDArray/prod + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + radians + ([& nd-array-and-params] + (util/coerce-return + (NDArray/radians + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-exponential + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_exponential + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-gamma + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_gamma + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-generalized-negative-binomial + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_generalized_negative_binomial + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-negative-binomial + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_negative_binomial + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-normal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_normal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-poisson + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_poisson + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + random-uniform + ([& nd-array-and-params] + (util/coerce-return + (NDArray/random_uniform + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rcbrt + ([& nd-array-and-params] + (util/coerce-return + (NDArray/rcbrt + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + reciprocal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/reciprocal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + relu + ([& nd-array-and-params] + (util/coerce-return + (NDArray/relu + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + repeat + ([& nd-array-and-params] + (util/coerce-return + (NDArray/repeat + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + reshape + ([ndarray Shape-or-vec-of-ints] + (util/coerce-return + (.reshape + ndarray + (util/coerce-param + Shape-or-vec-of-ints + #{"org.apache.mxnet.Shape" "int<>"}))))) + +(defn + reshape-like + ([& nd-array-and-params] + (util/coerce-return + (NDArray/reshape_like + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + reverse + ([& nd-array-and-params] + (util/coerce-return + (NDArray/reverse + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rint + ([& nd-array-and-params] + (util/coerce-return + (NDArray/rint + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rmsprop-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/rmsprop_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rmspropalex-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/rmspropalex_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + round + ([& nd-array-and-params] + (util/coerce-return + (NDArray/round + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + rsqrt + ([& nd-array-and-params] + (util/coerce-return + (NDArray/rsqrt + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-exponential + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_exponential + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-gamma + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_gamma + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-generalized-negative-binomial + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_generalized_negative_binomial + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-multinomial + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_multinomial + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-negative-binomial + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_negative_binomial + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-normal + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_normal + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-poisson + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_poisson + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sample-uniform + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sample_uniform + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + scatter-nd + ([& nd-array-and-params] + (util/coerce-return + (NDArray/scatter_nd + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn serialize ([ndarray] (util/coerce-return (.serialize ndarray)))) + +(defn + set + ([ndarray ndarray-or-num-or-vec-of-floats] + (util/coerce-return + (.set + ndarray + (util/coerce-param + ndarray-or-num-or-vec-of-floats + #{"float" "float<>" "org.apache.mxnet.NDArray"}))))) + +(defn + sgd-mom-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sgd_mom_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sgd-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sgd_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn shape ([ndarray] (util/coerce-return (.shape ndarray)))) + +(defn + shuffle + ([& nd-array-and-params] + (util/coerce-return + (NDArray/shuffle + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sigmoid + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sigmoid + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sign + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sign + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + signsgd-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/signsgd_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + signum-update + ([& nd-array-and-params] + (util/coerce-return + (NDArray/signum_update + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sin + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sin + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sinh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sinh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn size ([ndarray] (util/coerce-return (.size ndarray)))) + +(defn + slice-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/slice_axis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + slice-like + ([& nd-array-and-params] + (util/coerce-return + (NDArray/slice_like + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + smooth-l1 + ([& nd-array-and-params] + (util/coerce-return + (NDArray/smooth_l1 + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softmax + ([& nd-array-and-params] + (util/coerce-return + (NDArray/softmax + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softmax-cross-entropy + ([& nd-array-and-params] + (util/coerce-return + (NDArray/softmax_cross_entropy + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + softsign + ([& nd-array-and-params] + (util/coerce-return + (NDArray/softsign + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sort + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sort + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + split + ([& nd-array-and-params] + (util/coerce-return + (NDArray/split + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sqrt + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sqrt + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + square + ([& nd-array-and-params] + (util/coerce-return + (NDArray/square + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + squeeze + ([& nd-array-and-params] + (util/coerce-return + (NDArray/squeeze + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + stack + ([& nd-array-and-params] + (util/coerce-return + (NDArray/stack + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + stop-gradient + ([& nd-array-and-params] + (util/coerce-return + (NDArray/stop_gradient + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sum + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sum + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + sum-axis + ([& nd-array-and-params] + (util/coerce-return + (NDArray/sum_axis + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + swapaxes + ([& nd-array-and-params] + (util/coerce-return + (NDArray/swapaxes + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + take + ([& nd-array-and-params] + (util/coerce-return + (NDArray/take + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + tan + ([& nd-array-and-params] + (util/coerce-return + (NDArray/tan + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + tanh + ([& nd-array-and-params] + (util/coerce-return + (NDArray/tanh + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + tile + ([& nd-array-and-params] + (util/coerce-return + (NDArray/tile + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn to-array ([ndarray] (util/coerce-return (.toArray ndarray)))) + +(defn to-scalar ([ndarray] (util/coerce-return (.toScalar ndarray)))) + +(defn + topk + ([& nd-array-and-params] + (util/coerce-return + (NDArray/topk + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + tracing-enabled + ([ndarray] (util/coerce-return (.tracingEnabled ndarray)))) + +(defn + transpose + ([& nd-array-and-params] + (util/coerce-return + (NDArray/transpose + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + trunc + ([& nd-array-and-params] + (util/coerce-return + (NDArray/trunc + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn unary-- ([ndarray] (util/coerce-return (.unary_$minus ndarray)))) + +(defn + uniform + ([& nd-array-and-params] + (util/coerce-return + (NDArray/uniform + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn + wait-to-read + ([ndarray] (util/coerce-return (.waitToRead ndarray)))) + +(defn waitall ([] (util/coerce-return (NDArray/waitall)))) + +(defn + where + ([& nd-array-and-params] + (util/coerce-return + (NDArray/where + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + +(defn writable ([ndarray] (util/coerce-return (.writable ndarray)))) + +(defn + zeros-like + ([& nd-array-and-params] + (util/coerce-return + (NDArray/zeros_like + (util/coerce-param + nd-array-and-params + #{"scala.collection.Seq"}))))) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/symbol.clj new file mode 100644 index 000000000000..5c1efe6b453e --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/gen/symbol.clj @@ -0,0 +1,10940 @@ +(ns org.apache.clojure-mxnet.symbol + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle]) + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Symbol))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + div + ([sym sym-or-object] + (util/coerce-return + (.$div + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + div$m-dc$sp + ([sym double] + (util/coerce-return + (.$div$mDc$sp sym (util/nil-or-coerce-param double #{"double"}))))) + +(defn + div$m-fc$sp + ([sym num] + (util/coerce-return + (.$div$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + div$m-ic$sp + ([sym num] + (util/coerce-return + (.$div$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + > + ([sym sym-or-object] + (util/coerce-return + (.$greater + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + >= + ([sym sym-or-object] + (util/coerce-return + (.$greater$eq + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + >=$m-dc$sp + ([sym double] + (util/coerce-return + (.$greater$eq$mDc$sp + sym + (util/nil-or-coerce-param double #{"double"}))))) + +(defn + >=$m-fc$sp + ([sym num] + (util/coerce-return + (.$greater$eq$mFc$sp + sym + (util/nil-or-coerce-param num #{"float"}))))) + +(defn + >=$m-ic$sp + ([sym num] + (util/coerce-return + (.$greater$eq$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + >$m-dc$sp + ([sym double] + (util/coerce-return + (.$greater$mDc$sp + sym + (util/nil-or-coerce-param double #{"double"}))))) + +(defn + >$m-fc$sp + ([sym num] + (util/coerce-return + (.$greater$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + >$m-ic$sp + ([sym num] + (util/coerce-return + (.$greater$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + < + ([sym sym-or-object] + (util/coerce-return + (.$less + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + <= + ([sym sym-or-object] + (util/coerce-return + (.$less$eq + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + <=$m-dc$sp + ([sym double] + (util/coerce-return + (.$less$eq$mDc$sp + sym + (util/nil-or-coerce-param double #{"double"}))))) + +(defn + <=$m-fc$sp + ([sym num] + (util/coerce-return + (.$less$eq$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + <=$m-ic$sp + ([sym num] + (util/coerce-return + (.$less$eq$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + <$m-dc$sp + ([sym double] + (util/coerce-return + (.$less$mDc$sp sym (util/nil-or-coerce-param double #{"double"}))))) + +(defn + <$m-fc$sp + ([sym num] + (util/coerce-return + (.$less$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + <$m-ic$sp + ([sym num] + (util/coerce-return + (.$less$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + - + ([sym object-or-sym] + (util/coerce-return + (.$minus + sym + (util/nil-or-coerce-param + object-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + -$m-dc$sp + ([sym double] + (util/coerce-return + (.$minus$mDc$sp sym (util/nil-or-coerce-param double #{"double"}))))) + +(defn + -$m-fc$sp + ([sym num] + (util/coerce-return + (.$minus$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + -$m-ic$sp + ([sym num] + (util/coerce-return + (.$minus$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + % + ([sym sym-or-object] + (util/coerce-return + (.$percent + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + %$m-dc$sp + ([sym double] + (util/coerce-return + (.$percent$mDc$sp + sym + (util/nil-or-coerce-param double #{"double"}))))) + +(defn + %$m-fc$sp + ([sym num] + (util/coerce-return + (.$percent$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + %$m-ic$sp + ([sym num] + (util/coerce-return + (.$percent$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + + + ([sym object-or-sym] + (util/coerce-return + (.$plus + sym + (util/nil-or-coerce-param + object-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + +$m-dc$sp + ([sym double] + (util/coerce-return + (.$plus$mDc$sp sym (util/nil-or-coerce-param double #{"double"}))))) + +(defn + +$m-fc$sp + ([sym num] + (util/coerce-return + (.$plus$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + +$m-ic$sp + ([sym num] + (util/coerce-return + (.$plus$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + * + ([sym sym-or-object] + (util/coerce-return + (.$times + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + *$m-dc$sp + ([sym double] + (util/coerce-return + (.$times$mDc$sp sym (util/nil-or-coerce-param double #{"double"}))))) + +(defn + *$m-fc$sp + ([sym num] + (util/coerce-return + (.$times$mFc$sp sym (util/nil-or-coerce-param num #{"float"}))))) + +(defn + *$m-ic$sp + ([sym num] + (util/coerce-return + (.$times$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + ** + ([sym object-or-sym] + (util/coerce-return + (.$times$times + sym + (util/nil-or-coerce-param + object-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + **$m-dc$sp + ([sym double] + (util/coerce-return + (.$times$times$mDc$sp + sym + (util/nil-or-coerce-param double #{"double"}))))) + +(defn + **$m-fc$sp + ([sym num] + (util/coerce-return + (.$times$times$mFc$sp + sym + (util/nil-or-coerce-param num #{"float"}))))) + +(defn + **$m-ic$sp + ([sym num] + (util/coerce-return + (.$times$times$mIc$sp sym (util/nil-or-coerce-param num #{"int"}))))) + +(defn + activation + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Activation + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (activation + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (activation + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (activation + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + batch-norm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/BatchNorm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (batch-norm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (batch-norm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (batch-norm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + batch-norm-v1 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/BatchNorm_v1 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (batch-norm-v1 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (batch-norm-v1 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (batch-norm-v1 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + bilinear-sampler + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/BilinearSampler + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (bilinear-sampler + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (bilinear-sampler + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (bilinear-sampler + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + block-grad + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/BlockGrad + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (block-grad + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (block-grad + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (block-grad + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cast + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Cast + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cast + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cast + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cast + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + concat + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Concat + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (concat + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (concat + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (concat + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + convolution + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Convolution + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (convolution + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (convolution + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (convolution + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + convolution-v1 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Convolution_v1 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (convolution-v1 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (convolution-v1 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (convolution-v1 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + correlation + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Correlation + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (correlation + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (correlation + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (correlation + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + crop + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Crop + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (crop + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (crop + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (crop + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + custom + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Custom + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (custom + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (custom + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (custom + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + deconvolution + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Deconvolution + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (deconvolution + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (deconvolution + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (deconvolution + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + dropout + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Dropout + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (dropout + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (dropout + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (dropout + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + element-wise-sum + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ElementWiseSum + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (element-wise-sum + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (element-wise-sum + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (element-wise-sum + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + embedding + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Embedding + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (embedding + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (embedding + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (embedding + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + flatten + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Flatten + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (flatten + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (flatten + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (flatten + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + fully-connected + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/FullyConnected + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (fully-connected + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (fully-connected + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (fully-connected + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + grid-generator + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/GridGenerator + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (grid-generator + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (grid-generator + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (grid-generator + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + group + ([symbol-list] + (util/coerce-return + (Symbol/Group + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}))))) + +(defn + identity-attach-kl-sparse-reg + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/IdentityAttachKLSparseReg + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (identity-attach-kl-sparse-reg + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (identity-attach-kl-sparse-reg + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (identity-attach-kl-sparse-reg + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + instance-norm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/InstanceNorm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (instance-norm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (instance-norm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (instance-norm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + l2-normalization + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/L2Normalization + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (l2-normalization + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (l2-normalization + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (l2-normalization + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + lrn + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/LRN + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (lrn + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (lrn + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (lrn + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + layer-norm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/LayerNorm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (layer-norm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (layer-norm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (layer-norm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + leaky-re-lu + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/LeakyReLU + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (leaky-re-lu + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (leaky-re-lu + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (leaky-re-lu + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linear-regression-output + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/LinearRegressionOutput + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linear-regression-output + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linear-regression-output + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linear-regression-output + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + logistic-regression-output + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/LogisticRegressionOutput + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (logistic-regression-output + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (logistic-regression-output + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (logistic-regression-output + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + mae-regression-output + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/MAERegressionOutput + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (mae-regression-output + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (mae-regression-output + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (mae-regression-output + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + make-loss + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/MakeLoss + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (make-loss + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (make-loss + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (make-loss + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + pad + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Pad + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (pad + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (pad + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (pad + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + pooling + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Pooling + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (pooling + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (pooling + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (pooling + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + pooling-v1 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Pooling_v1 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (pooling-v1 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (pooling-v1 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (pooling-v1 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rnn + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/RNN + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rnn + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rnn + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rnn + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + roi-pooling + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ROIPooling + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (roi-pooling + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (roi-pooling + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (roi-pooling + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + reshape + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Reshape + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (reshape + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (reshape + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (reshape + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + svm-output + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SVMOutput + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (svm-output + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (svm-output + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (svm-output + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sequence-last + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SequenceLast + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sequence-last + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sequence-last + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sequence-last + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sequence-mask + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SequenceMask + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sequence-mask + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sequence-mask + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sequence-mask + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sequence-reverse + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SequenceReverse + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sequence-reverse + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sequence-reverse + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sequence-reverse + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + slice-channel + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SliceChannel + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (slice-channel + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (slice-channel + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (slice-channel + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softmax + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/Softmax + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softmax + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softmax + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softmax + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softmax-activation + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SoftmaxActivation + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softmax-activation + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softmax-activation + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softmax-activation + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softmax-output + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SoftmaxOutput + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softmax-output + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softmax-output + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softmax-output + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + spatial-transformer + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SpatialTransformer + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (spatial-transformer + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (spatial-transformer + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (spatial-transformer + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + swap-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/SwapAxis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (swap-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (swap-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (swap-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + up-sampling + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/UpSampling + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (up-sampling + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (up-sampling + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (up-sampling + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + abs + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/abs + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (abs + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (abs + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (abs + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + adam-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/adam_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (adam-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (adam-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (adam-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + add-n + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/add_n + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (add-n + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (add-n + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (add-n + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + apply + ([sym sym-name kwargs-map] + (util/coerce-return + (.apply + sym + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}))))) + +(defn + arccos + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arccos + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arccos + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arccos + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arccos + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + arccosh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arccosh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arccosh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arccosh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arccosh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + arcsin + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arcsin + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arcsin + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arcsin + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arcsin + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + arcsinh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arcsinh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arcsinh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arcsinh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arcsinh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + arctan + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arctan + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arctan + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arctan + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arctan + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + arctanh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/arctanh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (arctanh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (arctanh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (arctanh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + argmax + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/argmax + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (argmax + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (argmax + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (argmax + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + argmax-channel + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/argmax_channel + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (argmax-channel + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (argmax-channel + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (argmax-channel + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + argmin + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/argmin + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (argmin + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (argmin + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (argmin + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + argsort + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/argsort + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (argsort + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (argsort + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (argsort + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + attr + ([sym sym-name] + (util/coerce-return + (.attr + sym + (util/nil-or-coerce-param sym-name #{"java.lang.String"}))))) + +(defn attr-map ([sym] (util/coerce-return (.attrMap sym)))) + +(defn + batch-dot + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/batch_dot + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (batch-dot + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (batch-dot + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (batch-dot + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + batch-take + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/batch_take + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (batch-take + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (batch-take + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (batch-take + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-add + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_add + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-add + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-add + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-add + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-axes + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_axes + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-axes + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-axes + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-axes + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_axis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-div + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_div + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-div + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-div + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-div + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-equal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_equal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-equal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-equal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-equal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-greater + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_greater + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-greater + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-greater + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-greater + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-greater-equal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_greater_equal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-greater-equal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-greater-equal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-greater-equal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-hypot + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_hypot + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-hypot + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-hypot + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-hypot + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-lesser + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_lesser + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-lesser + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-lesser + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-lesser + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-lesser-equal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_lesser_equal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-lesser-equal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-lesser-equal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-lesser-equal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-maximum + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_maximum + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-maximum + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-maximum + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-maximum + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-minimum + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_minimum + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-minimum + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-minimum + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-minimum + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-minus + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_minus + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-minus + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-minus + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-minus + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-mod + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_mod + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-mod + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-mod + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-mod + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-mul + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_mul + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-mul + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-mul + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-mul + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-not-equal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_not_equal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-not-equal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-not-equal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-not-equal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-plus + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_plus + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-plus + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-plus + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-plus + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-power + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_power + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-power + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-power + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-power + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-sub + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_sub + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-sub + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-sub + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-sub + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + broadcast-to + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/broadcast_to + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (broadcast-to + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (broadcast-to + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (broadcast-to + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cast + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/cast + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cast + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cast + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cast + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cast-storage + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/cast_storage + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cast-storage + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cast-storage + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cast-storage + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cbrt + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/cbrt + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cbrt + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cbrt + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cbrt + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + ceil + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ceil + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (ceil + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (ceil + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (ceil + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + choose-element-0index + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/choose_element_0index + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (choose-element-0index + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (choose-element-0index + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (choose-element-0index + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + clip + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/clip + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (clip + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (clip + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (clip + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn clone ([sym] (util/coerce-return (.clone sym)))) + +(defn + concat + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/concat + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (concat + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (concat + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (concat + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cos + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/cos + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cos + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cos + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cos + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + cosh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/cosh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (cosh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (cosh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (cosh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + create-from-listed-symbols + ([sym-name sym-name-1 kwargs-map Symbol<> kwargs-map-1] + (util/coerce-return + (Symbol/createFromListedSymbols + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param sym-name-1 #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param Symbol<> #{"org.apache.mxnet.Symbol<>"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"}))))) + +(defn + create-from-listed-symbols-no-check + ([sym-name sym-name-1 kwargs-map Symbol<> kwargs-map-1] + (util/coerce-return + (Symbol/createFromListedSymbolsNoCheck + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param sym-name-1 #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param Symbol<> #{"org.apache.mxnet.Symbol<>"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"}))))) + +(defn + create-from-named-symbols + ([sym-name sym-name-1 kwargs-map kwargs-map-1 kwargs-map-1] + (util/coerce-return + (Symbol/createFromNamedSymbols + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param sym-name-1 #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"}))))) + +(defn + create-from-named-symbols-no-check + ([sym-name sym-name-1 kwargs-map kwargs-map-1] + (util/coerce-return + (Symbol/createFromNamedSymbolsNoCheck + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param sym-name-1 #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"}))))) + +(defn creation-trace ([sym] (util/coerce-return (.creationTrace sym)))) + +(defn + crop + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/crop + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (crop + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (crop + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (crop + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn debug-str ([sym] (util/coerce-return (.debugStr sym)))) + +(defn + degrees + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/degrees + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (degrees + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (degrees + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (degrees + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn dispose ([sym] (util/coerce-return (.dispose sym)))) + +(defn + dot + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/dot + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (dot + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (dot + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (dot + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + elemwise-add + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/elemwise_add + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (elemwise-add + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (elemwise-add + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (elemwise-add + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + elemwise-div + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/elemwise_div + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (elemwise-div + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (elemwise-div + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (elemwise-div + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + elemwise-mul + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/elemwise_mul + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (elemwise-mul + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (elemwise-mul + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (elemwise-mul + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + elemwise-sub + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/elemwise_sub + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (elemwise-sub + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (elemwise-sub + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (elemwise-sub + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + equal + ([sym-or-sym-or-object object-or-sym-or-sym] + (util/coerce-return + (Symbol/equal + (util/nil-or-coerce-param + sym-or-sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}) + (util/nil-or-coerce-param + object-or-sym-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + exp + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/exp + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (exp + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (exp + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (exp + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + expand-dims + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/expand_dims + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (expand-dims + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (expand-dims + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (expand-dims + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + expm1 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/expm1 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (expm1 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (expm1 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (expm1 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + fill-element-0index + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/fill_element_0index + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (fill-element-0index + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (fill-element-0index + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (fill-element-0index + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn finalize ([sym] (util/coerce-return (.finalize sym)))) + +(defn + fix + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/fix + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (fix + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (fix + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (fix + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + flatten + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/flatten + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (flatten + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (flatten + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (flatten + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + flip + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/flip + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (flip + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (flip + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (flip + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + floor + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/floor + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (floor + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (floor + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (floor + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + ftml-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ftml_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (ftml-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (ftml-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (ftml-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + ftrl-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ftrl_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (ftrl-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (ftrl-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (ftrl-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + gamma + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/gamma + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (gamma + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (gamma + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (gamma + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + gammaln + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/gammaln + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (gammaln + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (gammaln + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (gammaln + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + gather-nd + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/gather_nd + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (gather-nd + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (gather-nd + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (gather-nd + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + get + ([sym sym-name-or-num] + (util/coerce-return + (.get + sym + (util/nil-or-coerce-param + sym-name-or-num + #{"int" "java.lang.String"}))))) + +(defn get-internals ([sym] (util/coerce-return (.getInternals sym)))) + +(defn + greater + ([sym-or-sym sym-or-object] + (util/coerce-return + (Symbol/greater + (util/nil-or-coerce-param sym-or-sym #{"org.apache.mxnet.Symbol"}) + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + greater-equal + ([sym-or-sym sym-or-object] + (util/coerce-return + (Symbol/greaterEqual + (util/nil-or-coerce-param sym-or-sym #{"org.apache.mxnet.Symbol"}) + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn handle ([sym] (util/coerce-return (.handle sym)))) + +(defn + identity + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/identity + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (identity + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (identity + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (identity + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + infer-shape + ([sym kwargs-map-or-symbol-list] + (util/coerce-return + (.inferShape + sym + (util/nil-or-coerce-param + kwargs-map-or-symbol-list + #{"scala.collection.Seq" "scala.collection.immutable.Map"})))) + ([sym vec-or-strings vec-of-ints vec-of-ints-1] + (util/coerce-return + (.inferShape + sym + (util/nil-or-coerce-param vec-or-strings #{"java.lang.String<>"}) + (util/nil-or-coerce-param vec-of-ints #{"int<>"}) + (util/nil-or-coerce-param vec-of-ints-1 #{"int<>"}))))) + +(defn + infer-type + ([sym symbol-list-or-kwargs-map] + (util/coerce-return + (.inferType + sym + (util/nil-or-coerce-param + symbol-list-or-kwargs-map + #{"scala.collection.Seq" "scala.collection.immutable.Map"}))))) + +(defn is-disposed ([sym] (util/coerce-return (.isDisposed sym)))) + +(defn + khatri-rao + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/khatri_rao + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (khatri-rao + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (khatri-rao + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (khatri-rao + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + lesser + ([sym-or-sym sym-or-object] + (util/coerce-return + (Symbol/lesser + (util/nil-or-coerce-param sym-or-sym #{"org.apache.mxnet.Symbol"}) + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + lesser-equal + ([sym-or-sym sym-or-object] + (util/coerce-return + (Symbol/lesserEqual + (util/nil-or-coerce-param sym-or-sym #{"org.apache.mxnet.Symbol"}) + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + linalg-gelqf + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_gelqf + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-gelqf + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-gelqf + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-gelqf + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-gemm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_gemm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-gemm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-gemm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-gemm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-gemm2 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_gemm2 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-gemm2 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-gemm2 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-gemm2 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-potrf + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_potrf + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-potrf + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-potrf + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-potrf + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-potri + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_potri + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-potri + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-potri + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-potri + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-sumlogdiag + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_sumlogdiag + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-sumlogdiag + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-sumlogdiag + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-sumlogdiag + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-syrk + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_syrk + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-syrk + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-syrk + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-syrk + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-trmm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_trmm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-trmm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-trmm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-trmm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + linalg-trsm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/linalg_trsm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (linalg-trsm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (linalg-trsm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (linalg-trsm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn list-arguments ([sym] (util/coerce-return (.listArguments sym)))) + +(defn list-attr ([sym] (util/coerce-return (.listAttr sym)))) + +(defn + list-auxiliary-states + ([sym] (util/coerce-return (.listAuxiliaryStates sym)))) + +(defn list-outputs ([sym] (util/coerce-return (.listOutputs sym)))) + +(defn + load + ([sym-name] + (util/coerce-return + (Symbol/load + (util/nil-or-coerce-param sym-name #{"java.lang.String"}))))) + +(defn + load-json + ([sym-name] + (util/coerce-return + (Symbol/loadJson + (util/nil-or-coerce-param sym-name #{"java.lang.String"}))))) + +(defn + log + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/log + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (log + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (log + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (log + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + log10 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/log10 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (log10 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (log10 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (log10 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + log1p + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/log1p + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (log1p + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (log1p + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (log1p + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + log2 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/log2 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (log2 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (log2 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (log2 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + log-dispose-warning + ([sym] (util/coerce-return (.logDisposeWarning sym)))) + +(defn + log-softmax + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/log_softmax + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (log-softmax + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (log-softmax + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (log-softmax + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + make-loss + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/make_loss + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (make-loss + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (make-loss + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (make-loss + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + max-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/max_axis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (max-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (max-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (max-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + mean + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/mean + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (mean + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (mean + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (mean + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + min-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/min_axis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (min-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (min-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (min-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + mp-sgd-mom-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/mp_sgd_mom_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (mp-sgd-mom-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (mp-sgd-mom-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (mp-sgd-mom-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + mp-sgd-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/mp_sgd_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (mp-sgd-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (mp-sgd-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (mp-sgd-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + nanprod + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/nanprod + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (nanprod + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (nanprod + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (nanprod + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + nansum + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/nansum + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (nansum + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (nansum + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (nansum + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + negative + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/negative + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (negative + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (negative + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (negative + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + norm + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/norm + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (norm + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (norm + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (norm + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + normal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/normal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (normal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (normal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (normal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + not-equal + ([sym-or-sym-or-object sym-or-object-or-sym] + (util/coerce-return + (Symbol/notEqual + (util/nil-or-coerce-param + sym-or-sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}) + (util/nil-or-coerce-param + sym-or-object-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + one-hot + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/one_hot + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (one-hot + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (one-hot + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (one-hot + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + ones-like + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/ones_like + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (ones-like + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (ones-like + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (ones-like + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + org.apache.mxnet.symbol + ([sym long] + (util/coerce-return + (.org.apache.mxnet.Symbol + sym + (util/nil-or-coerce-param long #{"long"}))))) + +(defn + pad + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/pad + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (pad + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (pad + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (pad + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + pick + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/pick + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (pick + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (pick + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (pick + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + pow + ([sym-or-object-or-sym object-or-sym-or-sym] + (util/coerce-return + (Symbol/pow + (util/nil-or-coerce-param + sym-or-object-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}) + (util/nil-or-coerce-param + object-or-sym-or-sym + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + +(defn + prod + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/prod + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (prod + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (prod + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (prod + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + radians + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/radians + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (radians + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (radians + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (radians + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-exponential + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_exponential + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-exponential + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-exponential + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-exponential + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-gamma + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_gamma + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-gamma + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-gamma + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-gamma + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-generalized-negative-binomial + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_generalized_negative_binomial + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-generalized-negative-binomial + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-generalized-negative-binomial + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-generalized-negative-binomial + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-negative-binomial + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_negative_binomial + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-negative-binomial + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-negative-binomial + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-negative-binomial + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-normal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_normal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-normal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-normal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-normal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-poisson + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_poisson + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-poisson + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-poisson + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-poisson + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + random-uniform + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/random_uniform + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (random-uniform + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (random-uniform + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (random-uniform + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rcbrt + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/rcbrt + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rcbrt + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rcbrt + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rcbrt + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + reciprocal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/reciprocal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (reciprocal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (reciprocal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (reciprocal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + relu + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/relu + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (relu + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (relu + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (relu + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + repeat + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/repeat + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (repeat + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (repeat + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (repeat + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + reshape + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/reshape + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (reshape + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (reshape + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (reshape + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + reshape-like + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/reshape_like + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (reshape-like + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (reshape-like + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (reshape-like + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + reverse + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/reverse + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (reverse + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (reverse + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (reverse + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rint + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/rint + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rint + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rint + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rint + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rmsprop-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/rmsprop_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rmsprop-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rmsprop-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rmsprop-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rmspropalex-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/rmspropalex_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rmspropalex-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rmspropalex-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rmspropalex-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + round + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/round + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (round + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (round + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (round + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + rsqrt + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/rsqrt + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (rsqrt + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (rsqrt + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (rsqrt + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-exponential + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_exponential + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-exponential + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-exponential + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-exponential + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-gamma + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_gamma + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-gamma + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-gamma + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-gamma + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-generalized-negative-binomial + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_generalized_negative_binomial + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-generalized-negative-binomial + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-generalized-negative-binomial + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-generalized-negative-binomial + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-multinomial + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_multinomial + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-multinomial + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-multinomial + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-multinomial + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-negative-binomial + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_negative_binomial + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-negative-binomial + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-negative-binomial + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-negative-binomial + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-normal + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_normal + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-normal + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-normal + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-normal + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-poisson + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_poisson + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-poisson + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-poisson + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-poisson + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sample-uniform + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sample_uniform + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sample-uniform + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sample-uniform + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sample-uniform + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + save + ([sym sym-name] + (util/coerce-return + (.save + sym + (util/nil-or-coerce-param sym-name #{"java.lang.String"}))))) + +(defn + scatter-nd + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/scatter_nd + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (scatter-nd + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (scatter-nd + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (scatter-nd + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sgd-mom-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sgd_mom_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sgd-mom-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sgd-mom-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sgd-mom-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sgd-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sgd_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sgd-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sgd-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sgd-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + shuffle + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/shuffle + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (shuffle + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (shuffle + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (shuffle + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sigmoid + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sigmoid + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sigmoid + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sigmoid + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sigmoid + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sign + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sign + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sign + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sign + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sign + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + signsgd-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/signsgd_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (signsgd-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (signsgd-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (signsgd-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + signum-update + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/signum_update + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (signum-update + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (signum-update + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (signum-update + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sin + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sin + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sin + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sin + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sin + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sinh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sinh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sinh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sinh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sinh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + slice + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/slice + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (slice + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (slice + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (slice + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + slice-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/slice_axis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (slice-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (slice-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (slice-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + slice-like + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/slice_like + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (slice-like + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (slice-like + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (slice-like + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + smooth-l1 + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/smooth_l1 + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (smooth-l1 + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (smooth-l1 + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (smooth-l1 + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softmax + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/softmax + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softmax + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softmax + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softmax + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softmax-cross-entropy + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/softmax_cross_entropy + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softmax-cross-entropy + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softmax-cross-entropy + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softmax-cross-entropy + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + softsign + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/softsign + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (softsign + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (softsign + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (softsign + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sort + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sort + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sort + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sort + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sort + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + split + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/split + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (split + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (split + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (split + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sqrt + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sqrt + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sqrt + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sqrt + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sqrt + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + square + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/square + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (square + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (square + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (square + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + squeeze + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/squeeze + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (squeeze + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (squeeze + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (squeeze + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + stack + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/stack + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (stack + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (stack + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (stack + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + stop-gradient + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/stop_gradient + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (stop-gradient + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (stop-gradient + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (stop-gradient + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sum + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sum + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sum + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sum + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sum + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + sum-axis + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/sum_axis + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (sum-axis + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (sum-axis + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (sum-axis + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + swapaxes + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/swapaxes + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (swapaxes + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (swapaxes + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (swapaxes + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + take + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/take + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (take + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (take + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (take + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + tan + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/tan + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (tan + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (tan + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (tan + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + tanh + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/tanh + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (tanh + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (tanh + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (tanh + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + tile + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/tile + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (tile + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (tile + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (tile + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn to-json ([sym] (util/coerce-return (.toJson sym)))) + +(defn + topk + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/topk + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (topk + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (topk + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (topk + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + tracing-enabled + ([sym] (util/coerce-return (.tracingEnabled sym)))) + +(defn + transpose + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/transpose + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (transpose + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (transpose + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (transpose + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + trunc + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/trunc + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (trunc + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (trunc + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (trunc + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + uniform + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/uniform + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (uniform + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (uniform + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (uniform + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + where + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/where + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (where + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (where + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (where + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + +(defn + zeros-like + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/zeros_like + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (zeros-like + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ([sym-name kwargs-map-or-vec-or-sym] + (zeros-like + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil))) + ([kwargs-map-or-vec-or-sym] + (zeros-like + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/initializer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/initializer.clj new file mode 100644 index 000000000000..58413c811a11 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/initializer.clj @@ -0,0 +1,57 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.initializer + (:refer-clojure :exclude [apply]) + (:import (org.apache.mxnet Uniform Normal Xavier))) + +(defn uniform + "Initialize the weight with uniform [-scale, scale] + scale - The scale of uniform distribution" + ([scale] + (new Uniform (float scale))) + ([] + (uniform 0.07))) + +(defn normal + "Initialize the weight with normal(0, sigma) + sigma - Standard deviation for gaussian distribution." + ([sigma] + (new Normal (float sigma))) + ([] + (normal 0.01))) + +(defn xavier + "Initialize the weight with Xavier or similar initialization scheme + rand-type - 'gaussian' or 'uniform' + factor-type - 'avg' 'in' or 'out' + magnitude - scale of random number range " + ([{:keys [rand-type factor-type magnitude :as opts] + :or {rand-type "uniform" + factor-type "avg" + magnitude 3}}] + (new Xavier rand-type factor-type (float magnitude))) + ([] + (xavier {}))) + +(defn apply [initializer name arr] + (let [r (.apply initializer name arr)] + arr)) + +(defn init-weight [initializer name arr] + (doto initializer + (.initWeight name arr))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj new file mode 100644 index 000000000000..2f73beb12bbf --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj @@ -0,0 +1,315 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.io + (:refer-clojure :exclude [next]) + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.dtype :as dtype] + [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.random :as random]) + (:import (org.apache.mxnet IO DataDesc DataBatch NDArray) + (org.apache.mxnet.io ResizeIter PrefetchingIter NDArrayIter MXDataIter))) + + +(defn batches + "Convert the data-pack to a batch seq" + [data-pack] + (util/scala-iterator->seq (.toIterator data-pack))) + +(defn batch-label + "Returns the vector of ndarrays that represents the label" + [batch] + (util/scala-vector->vec (.label batch))) + +(defn batch-data + "Returns the vector of ndarrays that represents the data" + [batch] + (util/scala-vector->vec (.data batch))) + +(defn batch-index + "Returns the vector of ints that represents the index" + [batch] + (util/scala-vector->vec (.index batch))) + +(defn batch-pad + "Returns the pad of the batch" + [batch] + (.pad batch)) + +(defn iterator [data-pack] + (.iterator data-pack)) + +(defn resize-iter [iter nbatch ]) + +(defn provide-data [pack-iterator] + (->> pack-iterator + (.provideData) + (util/scala-map->map) + (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) + +(defn provide-label [pack-iterator] + (->> pack-iterator + (.provideLabel) + (util/scala-map->map) + (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) + +(defn reset [iterator] + (.reset iterator)) + +(defn has-next? [iterator] + (.hasNext iterator)) + +(defn next [iterator] + (.next iterator)) + +(defn iter-label [iterator] + (util/scala-vector->vec (.getLabel iterator))) + +(defn iter-data [iterator] + (util/scala-vector->vec (.getData iterator))) + +(defn iter-init-label [iterator] + (util/scala-vector->vec (.initLabel iterator))) + +(defmacro do-batches [iter f] + "Takes an iterator and a function of one argument. The iterator will be reset and run thhrough all the batches with the batch passed to the function argument. nil is returned" + `(do + (reset ~iter) + (loop [it# ~iter] + (when (has-next? it#) + (let [b# (next it#)] + (do (~f b#)) + (recur it#)))))) + +(defmacro for-batches + "Takes an iterator and a function of one argument. The iterator will be reset and run thhrough all the batches with the batch passed to the function argument. The result of the function will be conjed to a vector result of all the batches and returned at the end." + [iter f] + `(do + (reset ~iter) + (loop [it# ~iter + result# []] + (if (has-next? it#) + (let [b# (next it#)] + (recur it# (conj result# (do (~f b#))))) + result#)))) + +(defmacro reduce-batches + "Takes an iterator and a function of two arguments. The iterator will be reset and run thhrough all the batches with the batch passed to the function argument. The result of the function will the result of the reduce function" + ([iter f initial-val] + `(do + (reset ~iter) + (loop [it# ~iter + result# ~initial-val] + (if (has-next? it#) + (let [b# (next it#) + r# (do (~f result# b#))] + (recur it# r#)) + result#)))) + ([iter f] + `(reduce-batches ~iter ~f 0))) + +(defn + csv-iter + ([kwargs] + (util/apply-scala-fn (IO/CSVIter) (util/convert-io-map kwargs)))) + +(defn + csv-pack + ([kwargs] + (util/apply-scala-fn (IO/CSVPack) (util/convert-io-map kwargs)))) + +(defn + image-recode-pack + ([kwargs] + (util/apply-scala-fn + (IO/ImageRecodePack) + (util/convert-io-map kwargs)))) + +(defn + image-record-iter + ([kwargs] + (util/apply-scala-fn + (IO/ImageRecordIter) + (util/convert-io-map kwargs)))) + +(defn + mnist-iter + ([kwargs] + (util/apply-scala-fn (IO/MNISTIter) (util/convert-io-map kwargs)))) + +(defn + mnist-pack + ([kwargs] + (util/apply-scala-fn (IO/MNISTPack) (util/convert-io-map kwargs)))) + +(defn + create-iterator + ([iter-name kwargs-map] + (util/coerce-return (IO/createIterator iter-name (util/convert-io-map kwargs-map))))) + +(defn + create-mx-data-pack + ([pack-name kwargs-map] + (util/coerce-return (IO/createMXDataPack pack-name (util/convert-io-map kwargs-map))))) + +(defn resize-iter + "* Resize a data iterator to given number of batches per epoch. + * May produce incomplete batch in the middle of an epoch due + * to padding from internal iterator. + * + * @param data-iter Internal data iterator. + * @param resize number of batches per epoch to resize to. + * @param reset-internal whether to reset internal iterator with reset" + [data-iter resize reset-iternal] + (new ResizeIter data-iter resize reset-iternal)) + +(defn prefetching-iter + "Takes one or more data iterators and combines them with pre-fetching" + [iters data-names label-names] + (new PrefetchingIter + (util/vec->indexed-seq iters) + (->> data-names + (mapv util/convert-map) + (util/vec->indexed-seq)) + (->> label-names + (mapv util/convert-map) + (util/vec->indexed-seq)))) + +(defn ndarray-iter + " * NDArrayIter object in mxnet. Taking NDArray to get dataiter. + * + * @param data vector of iter + * @opts map of: + * :label Same as data, but is not fed to the model during testing. + * :data-batch-size Batch Size (default 1) + * :shuffle Whether to shuffle the data (default false) + * :last-batch-handle = pad, discard, or rollover. (default pad) + * :data-name String of data name (default data) + * :label-name String of label name (default label) + * How to handle the last batch + * This iterator will pad, discard or roll over the last batch if + * the size of data does not match batch-size. Roll over is intended + * for training and can cause problems if used for prediction." + ([data {:keys [label data-batch-size shuffle last-batch-handle data-name label-name] :as opts + :or {label nil + data-batch-size 1 + shuffle false + last-batch-handle "pad" + data-name "data" + label-name "label"}}] + (new NDArrayIter + (util/vec->indexed-seq data) + (if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) + (int data-batch-size) + shuffle + last-batch-handle + data-name + label-name)) + ([data] + (ndarray-iter data {}))) + +(defn dispose [iterator] + (.dispose iterator)) + +(s/def ::name string?) +(s/def ::shape vector?) +(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64}) +(s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout])) + + +;; NCHW is N:batch size C: channel H: height W: width +;;; other layouts are +;; NT, TNC, nad N +;; the shape length must match the lengh of the layout string size +(defn data-desc + ([{:keys [name shape dtype layout] :as opts + :or {dtype base/MX_REAL_TYPE}}] + (util/validate! ::data-desc opts "Invalid data description") + (let [sc (count shape) + layout (or layout (cond + (= 1 sc) "N" + (= 2 sc) "NT" + (= 3 sc) "TNC" + (= 4 sc) "NCHW" + :else (apply str (repeat sc "?"))))] + (new DataDesc name (mx-shape/->shape shape) dtype layout))) + ([name shape] + (data-desc {:name name :shape shape}))) + +(s/def ::ndarray #(instance? NDArray %)) +(s/def ::data vector?) +(s/def ::label (s/nilable (s/coll-of ::ndarray :kind vector?))) +(s/def ::index (s/nilable (s/coll-of int? :kind vector?))) +(s/def ::pad integer?) +(s/def ::bucket-key string?) +(s/def ::provided-data ::data-desc) +(s/def ::provided-label ::data-desc) +(s/def ::data-batch-class #(instance? DataBatch %)) + +(s/def ::data-batch + (s/or + :data-batch-class + ::data-batch-class + :data-batch-map + (s/keys :req-un [::data] :opt-un [::label ::index ::pad ::bucket-key ::provided-data ::provided-label]))) + +(defn data-batch + [{:keys [data label index pad bucket-key provided-data provided-label] :as info + :or {data [] label [] index [] pad 0}}] + ;; provided-data and provided label is a map of name to shape to indicate the order of the data/label loading + (util/validate! ::data-batch info "Invalid data batch") + (new DataBatch + (util/vec->indexed-seq data) + (util/vec->indexed-seq label) + (util/vec->indexed-seq index) + (int pad) + bucket-key + (when provided-data (util/list-map provided-data)) + (when provided-label(util/list-map provided-label)))) + +(defn rand-iter + "A implementation of a random noise iterator + Instead of data pass in the shape vector of the noise shape" + ([shape-vec {:keys [label data-batch-size shuffle last-batch-handle data-name label-name] :as opts + :or {label nil + data-batch-size 1 + shuffle false + last-batch-handle "pad" + data-name "rand" + label-name "label"}}] + (let [data [(ndarray/ones shape-vec)]] + (proxy [NDArrayIter] + [(util/vec->indexed-seq data) + (if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) + (int data-batch-size) + shuffle + last-batch-handle + data-name + label-name] + (provideData [] + (util/list-map {data-name (mx-shape/->vec (ndarray/shape (first data)))})) + (provideLabel [] (util/empty-list-map)) + (hasNext [] true) + (getData + ([] (util/vec->indexed-seq [(random/normal 0 1 (mx-shape/->vec (ndarray/shape (first data))))]))) + (getLabel + ([] (util/vec->indexed-seq [])))))) + ([shape-vec] + (rand-iter shape-vec {}))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore.clj new file mode 100644 index 000000000000..574bd77321e7 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore.clj @@ -0,0 +1,205 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.kvstore + (:refer-clojure :exclude [type]) + (:require [clojure.spec.alpha :as spec] + [org.apache.clojure-mxnet.util :as util] + [clojure.spec.alpha :as s]) + (:import (org.apache.mxnet KVStore NDArray))) + +(defn create + " Create a new KVStore + WARNING: it is your responsibility to clear this object through dispose. + - name : #{local, dist} (default is local) + The type of KVStore + - local works for multiple devices on a single machine (single process) + - dist works for multi-machines (multiple processes)" + ([name] + (KVStore/create name)) + ([] + (create "local"))) + +(defn dispose + "Release the native memory. + The object shall never be used after it is disposed." + [kvstore] + (.dispose kvstore)) + + +(s/def ::ks (s/or :string string? + :vec-of-string (s/coll-of string? :kind vector?))) +(s/def ::ndarray #(instance? NDArray %)) +(s/def ::vs (s/or :ndarray ::ndarray + :vec-of-ndarray (s/coll-of ::ndarray :kind vector?))) + +(defn init + "Initialize a single or a sequence of key-value pairs into the store. + For each key, one must init it before push and pull. + Only worker 0's (rank == 0) data are used. + This function returns after data have been initialized successfully + kvstore - KVstore + ks - keys (vec or strings or single string) + vs - values (vec or NDArrays or single ndarry)" + [kvstore ks vs] + (util/validate! ::ks ks "Invalid keys") + (util/validate! ::vs vs "Invalid values") + (doto kvstore + (.init (into-array (if (vector? ks) ks [ks])) + (into-array (if (vector? vs) vs [vs]))))) + +(s/def ::priority int?) + +(defn push + " Push a single or a sequence of key-value pairs into the store. + Data consistency: + 1. this function returns after adding an operator to the engine. + 2. push is always called after all previous push and pull on the same key are finished + 3. there is no synchronization between workers. One can use _barrier() to sync all workers + + -ks Keys + -vs values According values + - priority + The priority of the push operation. + The higher the priority, the faster this action is likely + to be executed before other push actions." + ([kvstore ks vs priority] + (util/validate! ::ks ks "Invalid keys") + (util/validate! ::vs vs "Invalid values") + (util/validate! ::priority priority "Invalid priority") + (let [store-vals (if (vector? vs) vs [vs]) + store-keys (if (vector? ks) ks (into [] (repeat (count store-vals) ks)))] + (doto kvstore + (.push (into-array store-keys) + (into-array store-vals) + (int priority))))) + ([kvstore ks vs] + (push kvstore ks vs 0))) + +(s/def ::outs (s/or :ndarray ::ndarray + :vec-of-ndarray (s/coll-of ::ndarray :kind vector?))) + +(defn pull + " Pull a single value or a sequence of values from the store. + Data consistency: + 1. this function returns after adding an operator to the engine. But any + further read on out will be blocked until it is finished. + 2. pull is always called after all previous push and pull on the same key are finished + 3. It pulls the newest value from the store. + - kvstore + - ks single or vector of (strings) + - outs single or vector of outs (NDArrays) + - priority + The priority of the push operation. + The higher the priority, the faster this action is likely + to be executed before other push actions." + ([kvstore ks outs priority] + (util/validate! ::ks ks "Invalid keys") + (util/validate! ::outs outs "Invalid outs") + (util/validate! ::priority priority "Invalid priority") + (let [store-vals (if (vector? outs) outs [outs]) + store-keys (if (vector? ks) ks (into [] (repeat (count store-vals) ks)))] + (doto kvstore + (.pull (into-array store-keys) + (into-array store-vals) + (int priority))))) + ([kvstore ks outs] + (pull kvstore ks outs 0))) + +(defn type + "Get the type of the kvstore" + [kvstore] + (.type kvstore)) + +(defn num-workers + "Get the number of worker nodes" + [kvstore] + (.numWorkers kvstore)) + +(defn rank + "Get the rank of this worker node + returns The rank of this node, which is in [0, get_num_workers()) " + [kvstore] + (.rank kvstore)) + +(defn set-optimizer + "Register an optimizer to the store + If there are multiple machines, this process (should be a worker node) + will pack this optimizer and send it to all servers. It returns after + this action is done" + [kvstore optimizer] + (doto kvstore + (.setOptimizer optimizer))) + +(defn barrier + "Global barrier among all worker nodes + For example, assume there are n machines, we want to let machine 0 first + init the values, and then pull the inited value to all machines. Before + pulling, we can place a barrier to guarantee that the initialization is + finished." + [kvstore] + (doto kvstore + (.barrier kvstore))) + +(defn num-dead-node [kvstore node-id] + (.numDeadNode kvstore (int node-id))) + +(defn set-barrier-before-exit + " Whether to do barrier when the kvstore finalizes + - kvstore + - barrier-before-exit boolean" + [kvstore barrier-before-exit] + (doto kvstore + (.setBarrierBeforeExit barrier-before-exit))) + +(s/def ::head int?) +(s/def ::body string?) + +(defn send-command-to-servers + "Send a command to all server nodes + Send a command to all server nodes, which will make each server node run + KVStoreServer.controller + This function returns after the command has been executed in all server nodes + -kvstore + -head the head of the command + - body the body of the command" + [kvstore head body] + (util/validate! ::head head "Invalid head") + (util/validate! ::body body "Invalid body") + (doto kvstore + (.sendCommandToServers (int head) body))) + + +(s/def ::fname string?) + +(defn save-optimizer-states + "Save optimizer (updater) state to file + - kvstore + - fname Path to output states file." + [kvstore fname] + (util/validate! ::fname fname "Invalid filename") + (doto kvstore + (.saveOptimizerStates fname))) + +(defn load-optimizer-states + "Load optimizer (updater) state from file + - kvstore + -fname Path to input states file." + [kvstore fname] + (util/validate! ::fname fname "Invalid filename") + (doto kvstore + (.loadOptimizerStates fname))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore_server.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore_server.clj new file mode 100644 index 000000000000..7fa116d45a91 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/kvstore_server.clj @@ -0,0 +1,39 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.kvstore-server + (:require [clojure.spec.alpha :as spec] + [org.apache.clojure-mxnet.util :as util] + [clojure.spec.alpha :as s]) + (:import (org.apache.mxnet KVStoreServer))) + + +(s/def ::env-map (s/map-of string? string?)) + +(defn init [env-map] + (util/validate! ::env-map env-map "Invalid environment map") + (KVStoreServer/init (util/convert-map env-map))) + + +(s/def ::die-if-others-go-out-timeout int?) + +(defn start + ([die-if-others-go-out-timeout] + (util/validate! ::die-if-others-go-out-timeout die-if-others-go-out-timeout "Invalid setting") + (KVStoreServer/start die-if-others-go-out-timeout)) + ([] + (start 0))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/lr_scheduler.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/lr_scheduler.clj new file mode 100644 index 000000000000..d08c40e33542 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/lr_scheduler.clj @@ -0,0 +1,28 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.lr-scheduler + (:import (org.apache.mxnet FactorScheduler))) + + +(defn factor-scheduler + "Assume the weight has been updated by n times, then the learning rate will + be base_lr * factor^^(floor(n/step)) + - step int, schedule learning rate after n updates + - factor number, the factor for reducing the learning rate" + [step factor] + (new FactorScheduler (int step) (float factor))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj new file mode 100644 index 000000000000..42d206a2dc2a --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj @@ -0,0 +1,691 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.module + (:refer-clojure :exclude [update symbol]) + (:require [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.initializer :as initializer] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [clojure.java.io :as io] + [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.ndarray :as ndarray]) + (:import (org.apache.mxnet.module Module FitParams BaseModule) + (org.apache.mxnet.io MXDataIter NDArrayIter) + (org.apache.mxnet Initializer Optimizer NDArray DataBatch + Context EvalMetric Monitor Callback$Speedometer DataDesc))) + + +(defn module + "Module is a basic module that wrap a symbol. + sym : Symbol definition. + map of options + :data-names - Input data names. + :label-names - Input label names + :contexts - Default is cpu(). + :workload-list - Default nil, indicating uniform workload. + :fixed-param-names Default nil, indicating no network parameters are fixed." + ([sym {:keys [data-names label-names contexts workload-list fixed-param-names] :as opts + :or {data-names ["data"] + label-names ["softmax_label"] + contexts [(context/default-context)]}}] + (new Module + sym + (util/vec->indexed-seq data-names) + (util/vec->indexed-seq label-names) + (into-array contexts) + (util/->option (when workload-list (util/vec->indexed-seq workload-list))) + (util/->option (when fixed-param-names (util/vec->set fixed-param-names))))) + ([sym data-names label-names contexts] + (module sym {:data-names data-names :label-names label-names :contexts contexts})) + ([sym] + (module sym {}))) + +(defn data-names [mod] + (.dataNames mod)) + +(defn data-shapes [mod] + (.dataShapes mod)) + +(defn label-shapes [mod] + (.labelShapes mod)) + +(defn output-names [mod] + (.outputNames mod)) + +(defn output-shapes [mod] + (.outputShapes mod)) + +(s/def ::data-shapes (s/coll-of ::mx-io/data-desc)) +(s/def ::label-shapes (s/coll-of ::mx-io/data-desc)) +(s/def ::for-training boolean?) +(s/def ::inputs-need-grad boolean?) +(s/def ::force-rebind boolean?) +(s/def ::shared-module #(instance? Module)) +(s/def ::grad-req string?) +(s/def ::bind-opts (s/keys :req-un [::data-shapes] :opt-un [::label-shapes ::for-training ::inputs-need-grad + ::force-rebind ::shared-module ::grad-req])) + +(defn bind + "Bind the symbols to construct executors. This is necessary before one + can perform computation with the module. + mod : module + map of opts: + :data-shapes Typically is (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout + :label-shapes Typically is (provide-label data-iter). map of :name :shape :dtype and :layout + :for-training Default is `true`. Whether the executors should be bind for training. + :inputs-need-grad Default is `false`. + Whether the gradients to the input data need to be computed. + Typically this is not needed. + But this might be needed when implementing composition of modules. + :force-rebind Default is `false`. + This function does nothing if the executors are already binded. + But with this `true`, the executors will be forced to rebind. + :shared-module Default is nil. This is used in bucketing. + When not `None`, the shared module essentially corresponds to + a different bucket -- a module with different symbol + but with the same sets of parameters + (e.g. unrolled RNNs with different lengths). " + [mod {:keys [data-shapes label-shapes for-training inputs-need-grad force-rebind + shared-module grad-req] :as opts + :or {for-training true + inputs-need-grad false + force-rebind false + grad-req "write"}}] + (util/validate! ::bind-opts opts "Incorrect bind options") + (doto mod + (.bind + (->> data-shapes + (map mx-io/data-desc) + (util/vec->indexed-seq)) + (util/->option (some->> label-shapes + (map mx-io/data-desc) + (util/vec->indexed-seq))) + for-training + inputs-need-grad + force-rebind + (util/->option shared-module) + grad-req))) + +(s/def ::intializer #(instance? Initializer %)) +(s/def ::arg-params map?) +(s/def ::aux-params map?) +(s/def ::force-init boolean?) +(s/def ::allow-extra boolean?) +(s/def ::init-params-opts (s/keys :opt-un [::initializer ::arg-params ::aux-params + ::force-init ::allow-extra])) + +(defn init-params + " Initialize the parameters and auxiliary states. + options map + :initializer - Called to initialize parameters if needed. + :arg-params - If not nil, should be a map of existing arg-params. + Initialization will be copied from that. + :auxParams - If not nil, should be a map of existing aux-params. + Initialization will be copied from that. + :allow-missing - If true, params could contain missing values, + and the initializer will be called to fill those missing params. + :force-init - If true, will force re-initialize even if already initialized. + :allow-extra - Whether allow extra parameters that are not needed by symbol. + If this is True, no error will be thrown when argParams or auxParams + contain extra parameters that is not needed by the executor." + ([mod {:keys [initializer arg-params aux-params allow-missing force-init allow-extra] :as opts + :or {initializer (initializer/uniform 0.01) + allow-missing false + force-init false + allow-extra false}}] + (util/validate! ::init-params-opts opts "Invalid init-params opts") + (doto mod + (.initParams + initializer + (some-> arg-params (util/convert-map)) + (some-> aux-params (util/convert-map)) + allow-missing + force-init + allow-extra))) + ([mod] + (init-params mod {}))) + +(s/def ::optimizer #(instance? Optimizer %)) +(s/def ::kvstore string?) +(s/def ::reset-optimizer boolean?) +(s/def ::force-init boolean?) +(s/def ::init-optimizer-opts (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init])) + +(defn init-optimizer + " Install and initialize optimizers. + - mod Module + - options map of + - kvstore + - reset-optimizer Default `True`, indicating whether we should set + `rescaleGrad` & `idx2name` for optimizer according to executorGroup + - force-init Default `False`, indicating whether we should force + re-initializing the optimizer in the case an optimizer is already installed." + ([mod {:keys [kvstore optimizer reset-optimizer force-init] :as opts + :or {kvstore "local" + optimizer (optimizer/sgd) + reset-optimizer true + force-init false}}] + (util/validate! ::init-optimizer-opts opts "Invalid init-optimizer options") + (doto mod + (.initOptimizer kvstore optimizer reset-optimizer force-init))) + ([mod] + (init-optimizer mod {}))) + + +(defn forward + "Forward computation. + data-batch - input data of form io/data-batch either map or DataBatch + is-train - Default is nil, which means `is_train` takes the value of `for_training`." + ([mod data-batch is-train] + (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") + (doto mod + (.forward + (if (map? data-batch) + (mx-io/data-batch data-batch) + data-batch) + (util/->option is-train)))) + ([mod data-batch-map] + (forward mod data-batch-map nil))) + + +(s/def ::ndarray #(instance? NDArray %)) +(s/def ::out-grads (s/nilable (s/coll-of ::ndarray))) + +(defn backward + "Backward computation. + out-grads - Gradient on the outputs to be propagated back. + This parameter is only needed when bind is called + on outputs that are not a loss function." + ([mod out-grads] + (util/validate! ::out-grads out-grads "Invalid out-grads") + (doto mod + (.backward (some-> out-grads into-array)))) + ([mod] + (backward mod nil))) + +(defn forward-backward + "A convenient function that calls both `forward` and `backward`." + [mod data-batch] + (util/validate! ::mx-io/data-batch data-batch "Invalid data-batch") + (doto mod + (.forwardBackward data-batch))) + +(defn outputs + " Get outputs of the previous forward computation. + In the case when data-parallelism is used, + the outputs will be collected from multiple devices. + The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`, + those `NDArray` might live on different devices." + [mod] + (->> (.getOutputs mod) + (util/scala-vector->vec) + (mapv util/scala-vector->vec))) + +(defn update + "Update parameters according to the installed optimizer and the gradients computed + in the previous forward-backward batch." + [mod] + (doto mod + (.update))) + +(defn outputs-merged + " Get outputs of the previous forward computation. + return In the case when data-parallelism is used, + the outputs will be merged from multiple devices, + as they look like from a single executor. + The results will look like `[out1, out2]`" + [mod] + (->> (.getOutputsMerged mod) + (util/scala-vector->vec))) + +(defn input-grads + " Get the gradients to the inputs, computed in the previous backward computation. + In the case when data-parallelism is used, + the outputs will be collected from multiple devices. + The results will look like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]` + those `NDArray` might live on different devices." + [mod] + (->> (.getInputGrads mod) + (util/scala-vector->vec) + (mapv util/scala-vector->vec))) + +(defn input-grads-merged + " Get the gradients to the inputs, computed in the previous backward computation. + return In the case when data-parallelism is used, + the outputs will be merged from multiple devices, + as they look like from a single executor. + The results will look like `[grad1, grad2]`" + [mod] + (->> (.getInputGradsMerged mod) + (util/scala-vector->vec))) + +(s/def ::prefix string?) +(s/def ::epoch int?) +(s/def ::save-opt-states boolean?) +(s/def ::save-checkpoint-opts (s/keys :req-un [::prefix ::epoch] :opt-un [::save-opt-states ::save-checkpoint])) + +(defn save-checkpoint + " Save current progress to checkpoint. + Use mx.callback.module_checkpoint as epoch_end_callback to save during training. + - mod Module + - opt-map with + :prefix The file prefix to checkpoint to + :epoch The current epoch number + :save-opt-states Whether to save optimizer states for continue training " + ([mod {:keys [prefix epoch save-opt-states] :as opts + :or {save-opt-states false}}] + (util/validate! ::save-checkpoint-opts opts "Invalid save checkpoint opts") + (doto mod + (.saveCheckpoint prefix (int epoch) save-opt-states))) + ([mod prefix epoch] + (save-checkpoint mod {:prefix prefix :epoch epoch}))) + + +(s/def ::load-optimizer-states boolean?) +(s/def ::data-names (s/coll-of string? :kind vector?)) +(s/def ::label-names (s/coll-of string? :kind vector?)) +(s/def ::context #(instance? Context %)) +(s/def ::contexts (s/coll-of ::context :kind vector?)) +(s/def ::workload-list (s/coll-of number? :kind vector?)) +(s/def ::fixed-params-names (s/coll-of string? :kind vector?)) +(s/def ::load-checkpoint-opts (s/keys :req-un [::prefix ::epoch] + :opt-un [::load-optimizer-states ::data-names ::label-names + ::contexts ::workload-list ::fixed-param-names])) + +(defn load-checkpoint + "Create a model from previously saved checkpoint. + - mod module + - opts map of + - prefix Path prefix of saved model files. You should have prefix-symbol.json, + prefix-xxxx.params, and optionally prefix-xxxx.states, + where xxxx is the epoch number. + - epoch Epoch to load. + - load-optimizer-states Whether to load optimizer states. + Checkpoint needs to have been made with save-optimizer-states=True + - dataNames Input data names. + - labelNames Input label names + - contexts Default is cpu(). + - workload-list Default nil, indicating uniform workload. + - fixed-param-names Default nil, indicating no network parameters are fixed." + ([{:keys [prefix epoch load-optimizer-states data-names label-names contexts + workload-list fixed-param-names] :as opts + :or {load-optimizer-states false + data-names ["data"] + label-names ["softmax_label"] + contexts [(context/cpu)] + workload-list nil + fixed-param-names nil}}] + (util/validate! ::load-checkpoint-opts opts "Invalid load-checkpoint opts") + (Module/loadCheckpoint + prefix + (int epoch) + load-optimizer-states + (util/vec->indexed-seq data-names) + (util/vec->indexed-seq label-names) + (into-array contexts) + (util/->option (when workload-list (util/vec->indexed-seq workload-list))) + (util/->option (when fixed-param-names (util/vec->set fixed-param-names))))) + ([prefix epoch] + (load-checkpoint mod {:prefix prefix :epoch epoch}))) + +(defn load-optimizer-states [mod fname] + (.mod load fname)) + +(defn symbol [mod] + (.getSymbol mod)) + +(defn params [mod] + (map util/scala-map->map (util/coerce-return (.getParams mod)))) + +(defn arg-params [mod] + (util/scala-map->map (.argParams mod))) + +(defn aux-params [mod] + (util/scala-map->map (.auxParams mod))) + +(defn reshape + " Reshapes the module for new input shapes. + - mod module + - data-shapes Typically is `(provide-data data-iter) + - param label-shapes Typically is `(provide-label data-tier)`. " + ([mod data-shapes label-shapes] + (util/validate! ::data-shapes data-shapes "Invalid data-shapes") + (util/validate! (s/nilable ::label-shapes) label-shapes "Invalid label-shapes") + (doto mod + (.reshape + (->> data-shapes + (map mx-io/data-desc) + (util/vec->indexed-seq)) + (util/->option (some->> label-shapes + (map mx-shape/->shape) + (util/vec->indexed-seq)))))) + ([mod data-shapes] + (reshape mod data-shapes nil))) + +(s/def ::set-param-opts (s/keys :opt-un [::arg-params ::aux-params ::allow-missing ::force-init ::allow-extra])) + +(defn get-params [mod] + (.getParams mod)) + +(defn set-params + " Assign parameter and aux state values. + - mod module + - arg-params : map + map of name to value (`NDArray`) mapping. + - aux-params : map + map of name to value (`NDArray`) mapping. + - allow-missing : bool + If true, params could contain missing values, and the initializer will be + called to fill those missing params. + - force-init : bool + If true, will force re-initialize even if already initialized. + - allow-extra : bool + Whether allow extra parameters that are not needed by symbol. + If this is True, no error will be thrown when arg-params or aux-params + contain extra parameters that is not needed by the executor." + [mod {:keys [arg-params aux-params allow-missing force-init allow-extra] :as opts + :or {allow-missing false force-init true allow-extra false}}] + (util/validate! ::set-param-opts opts "Invalid set-params") + (doto mod + (.setParams + (util/convert-symbol-map arg-params) + (when aux-params (util/convert-symbol-map aux-params)) + allow-missing + force-init + allow-extra))) + +(defn install-monitor + "Install monitor on all executors" + [mod monitor] + (doto mod + (.installMonitor monitor))) + + +(defn borrow-optimizer + "Borrow optimizer from a shared module. Used in bucketing, where exactly the same + optimizer (esp. kvstore) is used. + - mod module + - shared-module" + [mod shared-module] + (doto mod + (.borrowOptimizer shared-module))) + +(defn save-optimizer-states + "Save optimizer (updater) state to file + - mod module + - fname Path to output states file." + [mod fname] + (doto mod + (.saveOptimizerStates mod fname))) + +(defn load-optimizer-states + "Load optimizer (updater) state from file + - mod module + - fname Path to input states file. + " + [mod fname] + (doto mod + (.loadOptimzerStates fname))) + + +(s/def ::eval-metric #(instance? EvalMetric %)) +(s/def ::labels (s/coll-of ::ndarray :kind vector?)) + + +(defn update-metric + "Evaluate and accumulate evaluation metric on outputs of the last forward computation. + - mod module + - eval-metric + - labels" + [mod eval-metric labels] + (util/validate! ::eval-metric eval-metric "Invalid eval metric") + (util/validate! ::labels labels "Invalid labels") + (doto mod + (.updateMetric eval-metric (util/vec->indexed-seq labels)))) + + +(s/def ::begin-epoch int?) +(s/def ::validation-metric ::eval-metric) +(s/def ::monitor #(instance? Monitor %)) +(s/def ::batch-end-callback #(instance? Callback$Speedometer %)) +(s/def ::fit-params-opts (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer + ::arg-params ::aux-params ::allow-missing ::force-rebind + ::force-init ::begin-epoch ::validation-metric ::monitor + ::batch-end-callback])) + + +;; callbacks are not supported for now +(defn fit-params + "Fit Params" + ([{:keys [eval-metric kvstore optimizer + initializer arg-params aux-params + allow-missing force-rebind force-init begin-epoch validation-metric monitor + batch-end-callback] :as opts + :or {eval-metric (eval-metric/accuracy) + kvstore "local" + optimizer (optimizer/sgd) + initializer (initializer/uniform 0.01) + allow-missing false + force-rebind false + force-init false + begin-epoch 0}}] + (util/validate! ::fit-params-opts opts "Invalid fit param opts") + (doto (new FitParams) + (.setEvalMetric eval-metric) + (.setKVStore kvstore) + (.setOptimizer optimizer) + (.setInitializer initializer) + (.setArgParams (some-> arg-params (util/convert-map))) + (.setAuxParams (some-> aux-params (util/convert-map))) + (.setAllowMissing allow-missing) + (.setForceRebind force-rebind) + (.setForceInit force-init) + (.setBeginEpoch (int begin-epoch)) + (.setValidationMetric validation-metric) + (.setMonitor monitor) + (.setBatchEndCallback batch-end-callback))) + ([] + (new FitParams))) + + +(s/def ::mx-data-iter #(instance? MXDataIter %)) +(s/def ::ndarray-iter #(instance? NDArrayIter %)) +(s/def ::train-data (s/or :mx-iter ::mx-data-iter :ndarry-iter ::ndarray-iter)) +(s/def ::eval-data ::train-data) +(s/def ::num-epoch int?) +(s/def ::fit-params #(instance? FitParams %)) +(s/def ::fit-options (s/keys :req-un [::train-data] :opt-un [::eval-data ::num-epoch ::fit-params])) + + +;;; High Level API + +(defn score + " Run prediction on `eval-data` and evaluate the performance according to `eval-metric`. + - mod module + - option map with + :eval-data : DataIter + :eval-metric : EvalMetric + :num-batch Number of batches to run. Default is `Integer.MAX_VALUE`, + indicating run until the `DataIter` finishes. + :batch-end-callback -not supported yet + :reset Default `True`, + indicating whether we should reset `eval-data` before starting evaluating. + :epoch Default 0. For compatibility, this will be passed to callbacks (if any). + During training, this will correspond to the training epoch number." + [mod {:keys [eval-data eval-metric num-batch reset epoch] :as opts + :or {num-batch Integer/MAX_VALUE + reset true + epoch 0}}] + (util/validate! ::score-opts opts "Invalid score options") + (do (eval-metric/reset eval-metric) + (eval-metric/get + (.score mod + eval-data + eval-metric + (int num-batch) + (util/->option nil) + (util/->option nil) + reset + (int epoch))))) + +(defn fit + "Train the module parameters. + - mod module + - train-data (data-iterator) + - eval-data (data-iterator)If not nil, will be used as validation set and evaluate + the performance after each epoch. + - num-epoch Number of epochs to run training. + - f-params Extra parameters for training (See fit-params)." + [mod {:keys [train-data eval-data num-epoch fit-params] :as opts + ` :or {num-epoch 1 + fit-params (new FitParams)}}] + (util/validate! ::fit-options opts "Invalid options for fit") + (let [fmod (-> mod + (bind {:data-shapes (mx-io/provide-data train-data) + :label-shapes (mx-io/provide-label train-data) + :for-training true + :force-rebind (.forceRebind fit-params)}) + (init-params (remove (fn [[k v]] (nil? v)) + {:initializer (.initializer fit-params) + :arg-params (.argParams fit-params) + :aux-params (.auxParams fit-params) + :allow-missing (.allowMissing fit-params)})) + (init-optimizer (remove (fn [[k v]] (nil? v)) + {:optimizer (.optimizer fit-params) + :kvstore (.kvstore fit-params)}))) + eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy)) + val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))] + (doseq [i (range num-epoch)] + (let [tic (System/currentTimeMillis)] + (mx-io/reduce-batches train-data + (fn [batch-num batch] + (-> fmod + (forward batch) + (backward) + (update) + (update-metric eval-metric (mx-io/batch-label batch))) + (when-let [cb (util/option->value (.batchEndCallback fit-params))] + (callback/invoke cb i batch-num eval-metric)) + (.dispose batch) + (inc batch-num))) + (println "Epoch " i " Train-" (eval-metric/get eval-metric)) + (println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic)) + + ;;sync across kvstores + (get-params fmod) + (when-let [cb (util/option->value (.epochEndCallback fit-params))] + (callback/invoke cb i 0 val-metric)) + + ;; evaluation on the validation set + (when eval-data + (let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})] + (println "Epoch " i " Validation- " res))))) + fmod) + ;; old way if the problem with the sizes get resolved in DataDesc + #_(doto mod + (.fit + train-data + (util/->option eval-data) + (int num-epoch) + fit-params))) + +(s/def ::eval-data ::train-data) +(s/def ::num-batch integer?) +(s/def ::reset boolean?) +(s/def ::predict-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) + +(defn predict-batch + "Run the predication on a data batch + - mod module + - data-batch data-batch" + [mod data-batch] + (util/validate! ::mx-io/data-batch data-batch "Invalid data batch") + (util/coerce-return (.predict mod (if (map? data-batch) + (mx-io/data-batch data-batch) + data-batch)))) + +(defn predict + "Run prediction and collect the outputs. + - mod module + - option map with + - :eval-data + - :num-batch Default is -1, indicating running all the batches in the data iterator. + - :reset Default is `True`, indicating whether we should reset the data iter before start + doing prediction. + The return value will be a vector of NDArrays `[out1, out2, out3]`. + Where each element is concatenation of the outputs for all the mini-batches." + [mod {:keys [eval-data num-batch reset] :as opts + :or {num-batch -1 + reset true}}] + (util/validate! ::predict-opts opts "Invalid opts for predict") + (util/scala-vector->vec (.predict mod eval-data (int num-batch) reset))) + + +(s/def ::predict-every-batch-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset])) + +(defn predict-every-batch + " Run prediction and collect the outputs. + - module + - option map with + :eval-data + :num-batch Default is -1, indicating running all the batches in the data iterator. + :reset Default is `True`, indicating whether we should reset the data iter before start + doing prediction. + The return value will be a nested list like + [[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]` + This mode is useful because in some cases (e.g. bucketing), + the module does not necessarily produce the same number of outputs." + [mod {:keys [eval-data num-batch reset] :as opts + :or {num-batch -1 + reset true}}] + (util/validate! ::predict-every-batch-opts opts "Invalid opts for predict-every-batch") + (mapv util/scala-vector->vec (util/scala-vector->vec (.predictEveryBatch mod eval-data (int num-batch) reset)))) + + + +(s/def ::score-opts (s/keys :req-un [::eval-data ::eval-metric] :opt-un [::num-batch ::reset ::epoch])) + + +(defn exec-group [mod] + (.execGroup mod)) + +(defn grad-arrays [mod] + (mapv vec (util/buffer->vec (.gradArrays (.execGroup mod))))) + +(comment + (require '[clojure.reflect :as r]) + (r/reflect DataDesc) + (new DataDesc) + + (.setEpochEndCallback (if epoch-end-callback + (util/->option epoch-end-callback) + (util/->option nil))) + (.setBatchEndCallback (if batch-end-callback + (util/->option batch-end-callback) + (util/->option nil))) + + (fit-params {:allow-missing true}) + (fit-params {}) + + ) + + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/monitor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/monitor.clj new file mode 100644 index 000000000000..0550b4c46533 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/monitor.clj @@ -0,0 +1,43 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.monitor + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Monitor))) + + +(defmacro monitor + "Monitor outputs, weights, and gradients for debugging. + - interval Number of batches between printing. + - stat-func A function that computes statistics of tensors. + Takes a NDArray and returns a NDArray. defaults + to mean absolute value |x|/size(x). Function must be in the form of clojure (fn [x])" + [interval stat-fun] + `(new Monitor (int ~interval) (util/scala-fn ~stat-fun))) + +(defn tic + "Start collecting stats for current batch. + Call before forward" + [monitor] + (doto monitor + (.tic))) + +(defn toc + "End collecting for current batch and return results. + Call after computation of current batch." + [monitor] + (map util/tuple->vec (util/scala-vector->vec (.toVector (.toc monitor))))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj new file mode 100644 index 000000000000..b471055cb816 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -0,0 +1,171 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.ndarray + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle]) + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as mx-context] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [clojure.reflect :as r] + [t6.from-scala.core :refer [$] :as $]) + (:import (org.apache.mxnet NDArray))) + + +;; loads the generated functions into the namespace +(do (clojure.core/load "gen/ndarray")) + +(defn ->vec + "Converts a nd-array to a vector (one dimensional)" + [ndarray] + (-> ndarray to-array aclone vec)) + +(defn empty + "Create an empty uninitialized new NDArray, with specified shape" + ([shape-vec {:keys [ctx dtype] + :or {ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} + :as opts}] + (NDArray/empty (mx-shape/->shape shape-vec) ctx dtype)) + ([shape-vec] + (empty shape-vec {}))) + +(defn zeros + "Create a new NDArray filled with 0, with specified shape." + ([shape-vec {:keys [ctx dtype] + :or {ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} + :as opts}] + (NDArray/zeros (mx-shape/->shape shape-vec) ctx dtype)) + ([shape-vec] + (zeros shape-vec {}))) + +(defn ones + "Create a new NDArray filled with 1, with specified shape." + ([shape-vec {:keys [ctx dtype] + :or {ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} + :as opts}] + (NDArray/ones (mx-shape/->shape shape-vec) ctx dtype)) + ([shape-vec] + (ones shape-vec {}))) + +(defn full + "Create a new NDArray filled with given value, with specified shape." + ([shape-vec value {:keys [ctx dtype] + :or {ctx (mx-context/default-context)} + :as opts}] + (NDArray/full (mx-shape/->shape shape-vec) value ctx)) + ([shape-vec value] + (full shape-vec value {}))) + +(defn array + "Create a new NDArray that copies content from source vector" + ([source-vec shape-vec {:keys [ctx dtype] + :or {ctx (mx-context/default-context)} + :as opts}] + (NDArray/array (float-array source-vec) (mx-shape/->shape shape-vec) ctx)) + ([source-vec shape-vec] + (array source-vec shape-vec {}))) + + +(defn arange + "Returns evenly spaced values within a given interval. + Values are generated within the half-open interval [`start`, `stop`). In other + words, the interval includes `start` but excludes `stop`." + ([start stop {:keys [step repeat ctx dtype] + :or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} + :as opts}] + (NDArray/arange (float start) ($/option(float stop)) step repeat ctx dtype)) + ([start stop] + (arange start stop {}))) + +(defn slice + "Return a sliced NDArray that shares memory with current one." + ([ndarray i] + (.slice ndarray (int i))) + ([ndarray start stop] + (.slice ndarray (int start) (int stop)))) + +(defn copy-to + "Copy the content of current array to other" + [source-ndarray target-ndarray] + (.copyTo source-ndarray target-ndarray)) + + +(defn save + "Save list of NDArray or dict of str->NDArray to binary file + (The name of the file.Can be S3 or HDFS address (remember built with S3 support)) + Example of fname: + * - `s3://my-bucket/path/my-s3-ndarray` + * - `hdfs://my-bucket/path/my-hdfs-ndarray` + * - `/path-to/my-local-ndarray`" + [fname map-of-name-to-ndarray] + (NDArray/save fname (util/coerce-param map-of-name-to-ndarray #{"scala.collection.immutable.Map"}))) + +(defn load + "Takes a filename and returns back a map of ndarray-name to ndarray" + [filename] + (let [info (NDArray/load filename) + [names ndarrays] (util/tuple->vec info)] + (into {} (map (fn [n a] {(str n) a}) names ndarrays)))) + +(defn save-to-file + "Save one ndarray to a file" + [fname ndarray] + (save fname {"default" ndarray})) + +(defn load-from-file + "Load one ndarry from a file" + [fname] + (first (load2-array fname))) + +(defn as-in-context + "Return an `NDArray` that lives in the target context. If the array + is already in that context, `self` is returned. Otherwise, a copy is made." + [ndarray ctx] + (.asInContext ndarray ctx)) + +(defn as-type + "Return a copied numpy array of current array with specified type." + [ndarray dtype] + (.asType ndarray dtype)) + +(defn / [ndarray num-or-NDArray] + (div ndarray num-or-NDArray)) + +(defn concatenate + ([ndarrays {:keys [axis always-copy] :or {axis 1 always-copy true}}] + (NDArray/concatenate (apply $/immutable-list ndarrays) (int axis) always-copy)) + ([ndarrays] + (NDArray/concatenate (apply $/immutable-list ndarrays)))) + +(defn ->raw [ndarray] + (-> ndarray internal .getRaw)) + +(defn ->float-vec [ndarray] + (-> ndarray internal .toFloatArray vec)) + +(defn ->int-vec [ndarray] + (-> ndarray internal .toIntArray vec)) + +(defn ->double-vec [ndarray] + (-> ndarray internal .toDoubleArray vec)) + +(defn ->byte-vec [ndarray] + (-> ndarray internal .toByteArray vec)) + +(defn shape-vec [ndarray] + (mx-shape/->vec (shape ndarray))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj new file mode 100644 index 000000000000..45dcc484eecb --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj @@ -0,0 +1,178 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.optimizer + (:refer-clojure :exclude [update]) + (:import (org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD))) + +(defn sgd + "A very simple SGD optimizer with momentum and weight regularization." + ([{:keys [learning-rate momentum wd clip-gradient lr-scheduler] :as opts + :or {learning-rate 0.01 + momentum 0.0 + wd 0.0001 + clip-gradient 0}}] + (new SGD (float learning-rate) (float momentum) (float wd) (float clip-gradient) lr-scheduler)) + ([] + (sgd {}))) + +(defn dcasgd + "DCASGD optimizer with momentum and weight regularization. + Implementation of paper 'Asynchronous Stochastic Gradient Descent with + Delay Compensation for Distributed Deep Learning'" + ([{:keys [learning-rate momentum lambda wd clip-gradient lr-scheduler] :as opts + :or {learning-rate 0.01 + momentum 0.0 + lambda 0.04 + wd 0.0 + clip-gradient 0}}] + (new DCASGD (float learning-rate) (float lambda) (float momentum) (float wd) (float clip-gradient) lr-scheduler)) + ([] + (dcasgd {}))) + +(defn nag + "SGD with nesterov. + It is implemented according to + https://github.com/torch/optim/blob/master/sgd.lua" + ([{:keys [learning-rate momentum wd clip-gradient lr-scheduler] :as opts + :or {learning-rate 0.01 + momentum 0.0 + wd 0.0001 + clip-gradient 0}}] + (new NAG (float learning-rate) (float momentum) (float wd) (float clip-gradient) lr-scheduler)) + ([] + (nag {}))) + +(defn ada-delta + "AdaDelta optimizer as described in Matthew D. Zeiler, 2012. + http://arxiv.org/abs/1212.5701" + ([{:keys [rho rescale-gradient epsilon wd clip-gradient] :as opts + :or {rho 0.05 + rescale-gradient 1.0 + epsilon 1e-8 + wd 0.0 + clip-gradient 0}}] + (new AdaDelta (float rho) (float rescale-gradient) (float epsilon) (float wd) (float clip-gradient))) + ([] + (ada-delta {}))) + +(defn rms-prop + "RMSProp optimizer as described in Tieleman & Hinton, 2012. + http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013. + - learningRate Step size. + - gamma1 decay factor of moving average for gradient, gradient^^2. + - gamma2 momentum factor of moving average for gradient. + - rescale-gradient rescaling factor of gradient. + - wd L2 regularization coefficient add to all the weights + - clip-gradient clip gradient in range [-clip_gradient, clip_gradient] + - lr-scheduler The learning rate scheduler" + ([{:keys [learning-rate rescale-gradient gamma1 gamma2 wd lr-scheduler clip-gradient] + :or {learning-rate 0.002 + rescale-gradient 1.0 + gamma1 0.95 + gamma2 0.9 + wd 0.0 + clip-gradient 0}}] + (new RMSProp (float learning-rate) (float rescale-gradient) (float gamma1) + (float gamma2) (float wd) lr-scheduler (float clip-gradient))) + ([] + (rms-prop {}))) + +(defn ada-grad + " AdaGrad optimizer as described in Duchi, Hazan and Singer, 2011. + http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf + + - learning-rate Step size. + - epsilon A small number to make the updating processing stable. + Default value is set to 1e-7. + - rescale-gradient rescaling factor of gradient. + - wd L2 regularization coefficient add to all the weights" + ([{:keys [learning-rate rescale-gradient epsilon wd] + :or {learning-rate 0.05 + rescale-gradient 1.0 + epsilon 1e-7 + wd 0.0}}] + (new AdaGrad (float learning-rate) (float rescale-gradient) (float epsilon) (float wd) )) + ([] + (ada-grad {}))) + +(defn adam + "Adam optimizer as described in [King2014] + + [King2014] Diederik Kingma, Jimmy Ba, + Adam: A Method for Stochastic Optimization, + http://arxiv.org/abs/1412.6980 + + - learning-rate Step size. + - beta1 Exponential decay rate for the first moment estimates. + - beta2 Exponential decay rate for the second moment estimates. + - epsilon + - decay-factor + - wd L2 regularization coefficient add to all the weights + - clip-gradient clip gradient in range [-clip_gradient, clip_gradient] + - lr-scheduler The learning rate scheduler" + ([{:keys [learning-rate beta1 beta2 epsilon decay-factor wd clip-gradient lr-scheduler] + :or {learning-rate 0.002 + beta1 0.9 + beta2 0.999 + epsilon 1e-8 + decay-factor (- 1 1e-8) + wd 0 + clip-gradient 0}}] + (new Adam (float learning-rate) (float beta1) (float beta2) (float epsilon) + (float decay-factor) (float wd) (float clip-gradient) lr-scheduler)) + ([] + (adam {}))) + +(defn sgld + "Stochastic Langevin Dynamics Updater to sample from a distribution. + + - learning-rate Step size. + - rescale-gradient rescaling factor of gradient. + - wd L2 regularization coefficient add to all the weights + - clip-gradient Float, clip gradient in range [-clip_gradient, clip_gradient] + - lr-scheduler The learning rate scheduler" + ([{:keys [learning-rate rescale-gradient wd clip-gradient lr-scheduler] + :or {learning-rate 0.01 + rescale-gradient 1 + wd 0.0001 + clip-gradient 0}}] + (new SGLD (float learning-rate) (float rescale-gradient) (float wd) + (float clip-gradient) lr-scheduler)) + ([] + (sgld {}))) + +(defn update + "Update the parameters. + - optimizer - the optimizer + - index An unique integer key used to index the parameters + - weight weight ndarray + - grad grad ndarray + - state NDArray or other objects returned by initState + The auxiliary state used in optimization. + " + [optimizer index weight grad state] + (doto optimizer + (.update (int index) weight grad state))) + + +(defn create-state + "Create additional optimizer state such as momentum." + [optimizer index weight] + (do + (.createState optimizer (int index) weight))) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/profiler.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/profiler.clj new file mode 100644 index 000000000000..0bc93cc558e5 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/profiler.clj @@ -0,0 +1,47 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.profiler + (:import (org.apache.mxnet Profiler)) + (:require [org.apache.clojure-mxnet.util :as util])) + +(defn profiler-set-config + " Set up the configure of profiler. + -mode, optional Indicting whether to enable the profiler, can + be symbolic or all. Default is symbolic. + -fileName, optional The name of output trace file. Default is profile.json." + [kwargs] + (Profiler/profilerSetConfig + (util/convert-io-map kwargs) )) + +(defn profiler-set-state + "Set up the profiler state to record operator. + -state, optional + - Indicting whether to run the profiler, can + be stop or run. Default is stop." + ([state] + (Profiler/profilerSetState state)) + ([] + (profiler-set-state false))) + +(defn dump-profile + " Dump profile and stop profiler. Use this to save profile + in advance in case your program cannot exit normally." + ([finished] + (Profiler/dumpProfile (int finished))) + ([] + (dump-profile 1))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj new file mode 100644 index 000000000000..99f09aab993f --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj @@ -0,0 +1,62 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.random + (:require [org.apache.clojure-mxnet.shape :as mx-shape]) + (:import (org.apache.mxnet Random))) + +(defn uniform + "Generate uniform distribution in [low, high) with shape. + low: The lower bound of distribution. + high: The upper bound of distribution. + shape-vec: vector shape of the ndarray generated. + opts-map { + ctx: Context of output ndarray, will use default context if not specified. + out: Output place holder} + returns: The result ndarray with generated result./" + ([low high shape-vec {:keys [ctx out] :as opts}] + (Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out)) + ([low high shape-vec] + (uniform low high shape-vec {}))) + +(defn normal + "Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape. + loc: The standard deviation of the normal distribution + scale: The upper bound of distribution. + shape-vec: vector shape of the ndarray generated. + opts-map { + ctx: Context of output ndarray, will use default context if not specified. + out: Output place holder} + returns: The result ndarray with generated result./" + ([loc scale shape-vec {:keys [ctx out] :as opts}] + (Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out)) + ([loc scale shape-vec] + (normal loc scale shape-vec {}))) + + +(defn seed + " Seed the random number generators in mxnet. + This seed will affect behavior of functions in this module, + as well as results from executors that contains Random number + such as Dropout operators. + + seed-state: The random number seed to set to all devices. + note: The random number generator of mxnet is by default device specific. + This means if you set the same seed, the random number sequence + generated from GPU0 can be different from CPU." + [seed-state] + (Random/seed (int seed-state))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/shape.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/shape.clj new file mode 100644 index 000000000000..684cb0fc56eb --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/shape.clj @@ -0,0 +1,35 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.shape + (:require [t6.from-scala.core :refer [$] :as $]) + (:import (org.apache.mxnet Shape))) + + +(defn ->shape [v] + (new Shape (apply $/immutable-list (map int v)))) + +(defn ->vec [shape-obj] + (-> shape-obj + .toArray + vec)) + +(defn length [shape-obj] + (.length shape-obj)) + +(defn product [shape] + (.product shape)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj new file mode 100644 index 000000000000..2dcd84756339 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj @@ -0,0 +1,239 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.symbol + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle]) + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as mx-context] + [org.apache.clojure-mxnet.executor :as ex] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [t6.from-scala.core :refer [$] :as $] + [org.apache.clojure-mxnet.ndarray :as ndarray]) + (:import (org.apache.mxnet Symbol))) + + +;; loads the generated functions into the namespace +(do (clojure.core/load "gen/symbol")) + +;;;;;; + +(defn variable + "Create a symbolic variable with a specified name. + attr-map: Additional attributes to set on the variable + shape-vec: The shape vector of the variable. If specified, this will be used during shape inference. + lr-mult: The learning rate multiplier + wd-mult: The weight decay multiplier for the input variable + dtype: The dtype for the input variable + kwarg-map: Additional attributes which must start and end with double underscores" + ([var-name] + (variable var-name {})) + ([var-name {:keys [attrs shape lr-mult wd-mult dtype kwargs]:as opts}] + (Symbol/Variable var-name + (when attrs (util/convert-symbol-map attrs)) + (when shape (mx-shape/->shape shape)) + (if lr-mult (float lr-mult)($/option nil)) + (if wd-mult (float wd-mult)($/option nil)) + dtype + (if kwargs (util/convert-symbol-map kwargs) (util/empty-map))))) + +(defn bind + "Bind the current symbol to get an executor. + sym: symbol + ctx: the device context of the generated executor to run on + bind-map: map of str to ndarray + bind-grad-map: map of str to ndarray" + ([sym ctx bind-map-or-vec bind-grads-map-or-vec grad-req bind-aux-map-or-vec] + (.bind sym + ctx + (util/coerce-param bind-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}) + (util/coerce-param bind-grads-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}) + grad-req + (util/coerce-param bind-aux-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}) + nil + nil)) + ([sym ctx bind-map-or-vec bind-grads-map-or-vec] + (.bind sym + ctx + (util/coerce-param bind-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}) + (util/coerce-param bind-grads-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}))) + ([sym ctx bind-map-or-vec] + (.bind sym + ctx + (util/coerce-param bind-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"}) + nil)) + ([sym bind-map-or-vec] + (.bind sym + (mx-context/default-context) + (util/coerce-param bind-map-or-vec #{"scala.collection.immutable.Map" "scala.collection.Seq"})))) + +(defn simple-bind + " Bind current symbol to get an executor, allocate all the ndarrays needed. + Allows specifying data types. + This function will ask user to pass in ndarray of position + they like to bind to, and it will automatically allocate the ndarray + for arguments and auxiliary states that user did not specify explicitly. + + ctx: The device context the generated executor to run on. + shape-vec-map: map of name->shape + opt-map: options map of: + :grad-req {'write', 'add', 'null'}, or list of str or dict of str to str, optional + Specifies how we should update the gradient to the args_grad. + - 'write' means everytime gradient is write to specified args_grad NDArray. + - 'add' means everytime gradient is add to the specified NDArray. + - 'null' means no action is taken, the gradient may not be calculated. + :type-map map of name->dtype. + Will return the generator" + ([sym ctx shape-vec-map {:keys [grad-req type-map] :as opts + :or {grad-req "write"}}] + (let [shape-map (->> shape-vec-map + (map (fn [[k v]] [k (mx-shape/->shape v)])) + (into {}))] + (.simpleBind sym ctx grad-req + (util/nil-or-coerce-param shape-map #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param type-map #{"scala.collection.immutable.Map"})))) + ([sym ctx shape-vec-map] + (simple-bind sym ctx shape-vec-map {})) + ([sym ctx] + (.simpleBind sym ctx "write" (util/empty-map) nil))) + +(defn ones + "Returns a new symbol of given shape and type, filled with ones" + ([shape-vec {:keys [ctx dtype] :as optss + :or {ctx nil dtype base/MX_REAL_TYPE}}] + (Symbol/ones (mx-shape/->shape shape-vec) dtype ctx)) + ([shape-vec] + (ones shape-vec {}))) + +(defn zeros + "Returns a new symbol of given shape and type, filled with zeros" + ([shape-vec {:keys [ctx dtype] :as opts + :or {ctx nil dtype base/MX_REAL_TYPE}}] + (Symbol/zeros (mx-shape/->shape shape-vec) dtype ctx)) + ([shape-vec] + (zeros shape-vec {}))) + +(defn arange + "Returns evenly spaced values within a given interval. + Values are generated within the half-open interval [`start`, `stop`). In other + words, the interval includes `start` but excludes `stop`." + ([start stop {:keys [step repeat dtype] + :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} + :as opts}] + (Symbol/arange (float start) ($/option(float stop)) step repeat nil dtype)) + ([start stop] + (arange start stop {}))) + +;;; manually defined because of a conflicting arity of 2 with the auto-gen +(defn min + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/min + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (min sym-name attr-map (util/empty-list) kwargs-map)) + ([kwargs-map] (min nil nil (util/empty-list) kwargs-map)) + ([sym1 sym2] + (util/coerce-return + (Symbol/min + (util/nil-or-coerce-param + sym1 + #{"ml.dmlc.mxnet.Symbol" "java.lang.Object"}) + (util/nil-or-coerce-param + sym2 + #{"ml.dmlc.mxnet.Symbol" "java.lang.Object"}))))) + +;;; manually defined because of a conflicting arity of 2 with the auto-gen + +(defn max + ([sym1 sym2] + (util/coerce-return + (Symbol/max + (util/nil-or-coerce-param + sym1 + #{"ml.dmlc.mxnet.Symbol" "java.lang.Object"}) + (util/nil-or-coerce-param + sym2 + #{"ml.dmlc.mxnet.Symbol" "java.lang.Object"})))) + ([sym-name kwargs-map symbol-list kwargs-map-1] + (util/coerce-return + (Symbol/max + (util/nil-or-coerce-param sym-name #{"java.lang.String"}) + (util/nil-or-coerce-param + kwargs-map + #{"scala.collection.immutable.Map"}) + (util/nil-or-coerce-param symbol-list #{"scala.collection.Seq"}) + (util/nil-or-coerce-param + kwargs-map-1 + #{"scala.collection.immutable.Map"})))) + ([sym-name attr-map kwargs-map] + (max sym-name attr-map (util/empty-list) kwargs-map)) + ([kwargs-map] (max nil nil (util/empty-list) kwargs-map))) + +;;; redefining to make it easier to work with + +(defn- coerce-infer-shape-return [ret] + (->> ret + (map util/scala-vector->vec) + (map (fn [shapes] (map mx-shape/->vec shapes))))) + +(defn + infer-shape + ([sym vec-or-strings vec-of-ints vec-of-ints-1] + (let [ret (util/coerce-return + (.inferShape + sym + (util/nil-or-coerce-param vec-or-strings #{"java.lang.String<>"}) + (util/nil-or-coerce-param vec-of-ints #{"int<>"}) + (util/nil-or-coerce-param vec-of-ints-1 #{"int<>"})))] + (coerce-infer-shape-return ret))) + ([sym symbol-list-or-kwargs-map] + (let [ret (util/coerce-return + (.inferShape + sym + (if (map? symbol-list-or-kwargs-map) + (util/convert-shape-map symbol-list-or-kwargs-map) + (util/nil-or-coerce-param + symbol-list-or-kwargs-map + #{"scala.collection.Seq" "scala.collection.immutable.Map"}))))] + (coerce-infer-shape-return ret)))) + +(defn + save-checkpoint + "Taken from the model save checkpoint" + [prefix epoch sym arg-params aux-params] + (do + (save sym (str prefix "-symbol.json")) + (let [save-map (merge (->> arg-params + (mapv (fn [[k v]] [(str "arg:" k) v])) + (into {})) + (->> aux-params + (mapv (fn [[k v]] [(str "aux:" k) v])) + (into {}))) + param-name (format "%s-%04d.params" prefix epoch)] + (ndarray/save param-name save-map) + (println "Saved checkpoint to " param-name))) + ) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj new file mode 100644 index 000000000000..f42a1248f236 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -0,0 +1,208 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.util + (:require [clojure.spec.alpha :as s] + [t6.from-scala.core :refer [$ $$] :as $] + [clojure.string :as string] + [org.apache.clojure-mxnet.shape :as mx-shape]) + (:import (org.apache.mxnet NDArray) + (scala Product Tuple2 Tuple3) + (scala.collection.immutable List IndexedSeq ListMap) + (scala.collection JavaConversions Map) + (scala Option))) + +(def ndarray-param-coerce {"float" "num" + "int" "num" + "boolean" "bool" + "scala.collection.immutable.Map" "kwargs-map" + "scala.collection.Seq" "& nd-array-and-params" + "int<>" "vec-of-ints" + "float<>" "vec-of-floats" + "byte<>" "byte-array" + "java.lang.String<>" "vec-or-strings" + "org.apache.mxnet.NDArray" "ndarray" + "org.apache.mxnet.Symbol" "sym"}) + +(def symbol-param-coerce {"java.lang.String" "sym-name" + "float" "num" + "int" "num" + "boolean" "bool" + "scala.collection.immutable.Map" "kwargs-map" + "scala.collection.Seq" "symbol-list" + "int<>" "vec-of-ints" + "float<>" "vec-of-floats" + "byte<>" "byte-array" + "java.lang.String<>" "vec-or-strings" + "org.apache.mxnet.Symbol" "sym" + "java.lang.Object" "object"}) + +(defn empty-list [] + ($ List/empty)) + +(defn empty-map [] + ($ Map/empty)) + +(defn empty-indexed-seq [] + ($ IndexedSeq/empty)) + +(defn empty-list-map [] + ($ ListMap/empty)) + +(defn ->option [v] + ($ Option v)) + +(defn option->value [opt] + ($/view opt)) + +(defn keyword->snake-case [vals] + (mapv (fn [v] (if (keyword? v) (string/replace (name v) "-" "_") v)) vals)) + +(defn convert-tuple [param] + (apply $/tuple param)) + + +(def tuple-param-names #{"kernel" "stride" "pad" "target-shape" "shape"}) + +(defn convert-by-shape [param] + (into {} (mapv (fn [[k v]] + [k (if (vector? v) (mx-shape/->shape v) v)]) + param))) + +(defn tuple-convert-by-param-name [param] + (into {} (mapv (fn [[k v]] + (if (or (get tuple-param-names k) + (get tuple-param-names (name k))) + [k (str (if (vector? v) (mx-shape/->shape v) v))] + [k v])) + param))) + +(def io-param-names #{"input-shape" "data-shape" "label-shape"}) + +(defn io-convert-by-param-name [param] + (into {} (mapv (fn [[k v]] (cond + (or (get io-param-names k) + (get io-param-names (name k))) [k (str (if (vector? v) (mx-shape/->shape v) v))] + (true? v) [k "True"] + (false? v) [k "False"] + :else [k (str v)])) + param))) + +(defn convert-map [param] + (if (empty? param) + (empty-map) + (apply $/immutable-map (->> param + (into []) + flatten + keyword->snake-case)))) + + +(defn convert-symbol-map [param] + (convert-map (tuple-convert-by-param-name param))) + +(defn convert-io-map [param] + (convert-map (io-convert-by-param-name param))) + +(defn convert-shape-map [param] + (convert-map (convert-by-shape param))) + +(defn convert-vector [param] + (apply $/immutable-list param)) + +(defn vec->set [param] + (apply $/immutable-set param)) + +(defn vec->indexed-seq [x] + (.toIndexedSeq (convert-vector x))) + +(defn apply-scala-fn [f args] + (.apply f args)) + +(defn coerce-param [param targets] + (cond + (and (get targets "scala.collection.immutable.Map") (map? param)) (convert-map param) + (and (get targets "float") (number? param)) (float param) + (and (get targets "scala.collection.Seq") (instance? org.apache.mxnet.NDArray param)) ($/immutable-list param) + (and (get targets "scala.collection.Seq") (instance? org.apache.mxnet.Symbol param)) ($/immutable-list param) + (and (get targets "scala.collection.Seq") (and (or (vector? param) (seq? param)) (empty? param))) (empty-list) + (and (get targets "scala.collection.Seq") (or (vector? param) (seq? param))) (apply $/immutable-list param) + (and (get targets "int<>") (vector? param)) (int-array param) + (and (get targets "float<>") (vector? param)) (float-array param) + (and (get targets "java.lang.String<>") (vector? param)) (into-array param) + :else param)) + +(defn nil-or-coerce-param [param targets] + (when param + (coerce-param param targets))) + +(defn scala-map->map + [^Map m] + (into {} (JavaConversions/mapAsJavaMap m))) + +(defn buffer->vec [b] + (into [] (JavaConversions/bufferAsJavaList b))) + +(defn scala-vector->vec [x] + (into [] (JavaConversions/asJavaCollection x))) + +(defn scala-iterator->seq [x] + (iterator-seq (JavaConversions/asJavaIterator x))) + +(defn tuple->vec [^Product p] + (->> (.productArity p) + (range) + (map #(.productElement p %)) + (into []))) + +(defn coerce-return [return-val] + (cond + (instance? scala.collection.mutable.ArrayBuffer return-val) (buffer->vec return-val) + (instance? scala.collection.immutable.Vector return-val) (scala-vector->vec return-val) + (instance? org.apache.mxnet.NDArrayFuncReturn return-val) (.head return-val) + (instance? Map return-val) (scala-map->map return-val) + (instance? Tuple2 return-val) (tuple->vec return-val) + (instance? Tuple3 return-val) (tuple->vec return-val) + :else return-val)) + +(defmacro scala-fn + "Creates a scala fn from an anonymous clojure fn of the form (fn [x] body)" + [f] + `($/fn ~@(drop-last (rest f)) ~(last f))) + +(defn translate-keyword-shape[[k v]] + [(if (keyword? k) (string/replace (name k) "-" "_") k) + (if (vector? v) (mx-shape/->shape v) v)]) + +(defn map->tuple [m] + (->> m + (into []) + (map translate-keyword-shape) + (map convert-tuple))) + +(defn list-map [m] + (loop [lm ($ ListMap/empty) + tuples (map->tuple m)] + (if (seq tuples) + (recur ($ lm "+" (first tuples)) (rest tuples)) + lm))) + +(defn validate! [spec value error-msg] + (when-not (s/valid? spec value) + (s/explain spec value) + (throw (ex-info error-msg + (s/explain-data spec value))))) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/visualization.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/visualization.clj new file mode 100644 index 000000000000..c7002bf676ad --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/visualization.clj @@ -0,0 +1,63 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.visualization + (:require [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.shape :as mx-shape]) + (:import (org.apache.mxnet Visualization))) + + + +(defn plot-network + "convert symbol to Dot object for visualization + - symbol symbol to be visualized + - title title of the dot graph + - shape-map Map of shapes, str -> shape, given input shapes + - node-attrs Map of node's attributes + for example: {:shape \"oval\" :fixedsize \"false\"} + + - hide-weight if true (default) then inputs with names like `*_weight` + or `*_bias` will be hidden + returns Dot object of symbol" + ([sym shape-map {:keys [title node-attrs hide-weights] :as opts + :or {title "plot" + hide-weights true}}] + (Visualization/plotNetwork sym + title + (->> shape-map + (map (fn [[k v]] [k (mx-shape/->shape v)])) + (into {}) + (util/convert-map)) + (util/convert-map node-attrs) + hide-weights)) + ([sym shape-map] + (plot-network sym shape-map {}))) + + +(defn render + " Render file with Graphviz engine into format. + - dot the dot file from plot-network function + - engine The layout commmand used for rendering ('dot', 'neato', ...). + - format The output format used for rendering ('pdf', 'png', ...). + - filename Name of the DOT source file to render. + - path Path to save the Dot source file. + " + ([dot engine format filename path] + (doto dot + (.render engine format filename path))) + ([dot filename path] + (render dot "dot" "pdf" filename path))) diff --git a/contrib/clojure-package/test/dev/generator_test.clj b/contrib/clojure-package/test/dev/generator_test.clj new file mode 100644 index 000000000000..b6f5f43f72a3 --- /dev/null +++ b/contrib/clojure-package/test/dev/generator_test.clj @@ -0,0 +1,193 @@ +(ns dev.generator-test + (:require [clojure.test :refer :all] + [dev.generator :as gen])) + +(deftest test-clojure-case + (is (= "foo-bar" (gen/clojure-case "FooBar"))) + (is (= "foo-bar-baz" (gen/clojure-case "FooBarBaz"))) + (is (= "foo-bar-baz" (gen/clojure-case "FOOBarBaz"))) + (is (= "foo-bar" (gen/clojure-case "foo_bar"))) + (is (= "foo-bar" (gen/clojure-case "Foo_Bar"))) + (is (= "div+" (gen/clojure-case "/+")))) + +(defn ndarray-reflect-info [name] + (->> gen/ndarray-public-no-default + (filter #(= name (str (:name %)))) + first)) + +(defn symbol-reflect-info [name] + (->> gen/symbol-public-no-default + (filter #(= name (str (:name %)))) + first)) + +(deftest test-symbol-transform-param-name + (let [params ["java.lang.String" + "scala.collection.immutable.Map" + "scala.collection.Seq" + "scala.collection.immutable.Map"] + transformed-params ["sym-name" + "kwargs-map" + "symbol-list" + "kwargs-map"]] + (is (= transformed-params (gen/symbol-transform-param-name params))) + (is (= transformed-params (gen/symbol-transform-param-name + (:parameter-types (symbol-reflect-info "floor"))))))) + + +(deftest test-ndarray-transform-param-name + (let [params ["scala.collection.immutable.Map" + "scala.collection.Seq"] + transformed-params ["kwargs-map" "& nd-array-and-params"]] + (is (= transformed-params (gen/ndarray-transform-param-name params))) + (is (= transformed-params (gen/ndarray-transform-param-name + (:parameter-types (ndarray-reflect-info "sqrt"))))))) + +(deftest test-has-variadic? + (is (false? (gen/has-variadic? ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"]))) + (is (true? (gen/has-variadic? ["kwargs-map" "& nd-array-and-params"])))) + +(deftest test-increment-param-name + (is (= "foo-1" (gen/increment-param-name "foo"))) + (is (= "foo-2" (gen/increment-param-name "foo-1")))) + +(deftest test-rename-duplicate-params + (is (= ["foo" "bar" "baz"] (gen/rename-duplicate-params ["foo" "bar" "baz"]))) + (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" "bar"])))) + +(deftest test-is-symbol-hand-gen? + (is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "max"))))) + (is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "Variable"))))) + (is (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "sqrt"))))) + +(deftest test-is-ndarray-hand-gen? + (is (not (false? (gen/is-ndarray-hand-gen? (ndarray-reflect-info "zeros"))))) + (is (false? (gen/is-ndarray-hand-gen? (ndarray-reflect-info "sqrt"))))) + +(deftest test-public-by-name-and-param-count + (let [lrn-info (get (gen/public-by-name-and-param-count gen/symbol-public-to-gen) + (symbol "LRN"))] + (is (= 4 (-> lrn-info keys first))) + (is (= "LRN" (-> lrn-info vals ffirst :name str))))) + +(deftest test-symbol-vector-args + (is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"}))) (gen/symbol-vector-args))) + +(deftest test-symbol-map-args + (is (= `(if (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)) + (gen/symbol-map-args))) + +(deftest test-add-symbol-arities + (let [params (map symbol ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"]) + function-name (symbol "foo") + [ar1 ar2 ar3] (gen/add-symbol-arities params function-name)] + (is (= '([sym-name attr-map kwargs-map] + (foo + sym-name + (util/convert-symbol-map attr-map) + (util/empty-list) + (util/convert-symbol-map kwargs-map))) + ar1)) + (is (= '([sym-name kwargs-map-or-vec-or-sym] + (foo + sym-name + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + ar2) + (is (= '([kwargs-map-or-vec-or-sym] + (foo + nil + nil + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (util/empty-list) + (util/coerce-param + kwargs-map-or-vec-or-sym + #{"scala.collection.Seq"})) + (if + (clojure.core/map? kwargs-map-or-vec-or-sym) + (org.apache.clojure-mxnet.util/convert-symbol-map + kwargs-map-or-vec-or-sym) + nil)))) + ar3))) + +(deftest test-gen-symbol-function-arity + (let [op-name (symbol "$div") + op-values {1 [{:name (symbol "$div") + :return-type "org.apache.mxnet.Symbol," + :declaring-class "org.apache.mxnet.Symbol," + :parameter-types ["org.apache.mxnet.Symbol"], + :exception-types [], + :flags #{:public}} + {:name (symbol "$div") :return-type "org.apache.mxnet.Symbol," + :declaring-class "org.apache.mxnet.Symbol," + :parameter-types ["java.lang.Object"], + :exception-types [], + :flags #{:public}}]} + function-name (symbol "div")] + (is (= '(([sym sym-or-Object] + (util/coerce-return + (.$div + sym + (util/nil-or-coerce-param + sym-or-Object + #{"org.apache.mxnet.Symbol" "java.lang.Object"})))))) + (gen/gen-symbol-function-arity op-name op-values function-name)))) + +(deftest test-gen-ndarray-function-arity + (let [op-name (symbol "$div") + op-values {1 [{:name (symbol "$div") + :return-type "org.apache.mxnet.NDArray," + :declaring-class "org.apache.mxnet.NDArray," + :parameter-types ["float"], + :exception-types [], + :flags #{:public}} + {:name (symbol "$div") + :return-type "org.apache.mxnet.NDArray," + :declaring-class "org.apache.mxnet.NDArray," + :parameter-types ["org.apache.mxnet.NDArray"], + :exception-types [], + :flags #{:public}}]}] + (is (= '(([ndarray num-or-ndarray] + (util/coerce-return + (.$div + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"})))))) + (gen/gen-ndarray-function-arity op-name op-values)))) + +(deftest test-write-to-file + (testing "symbol" + (let [fname "test/test-symbol.clj" + _ (gen/write-to-file [(first gen/all-symbol-functions)] + gen/symbol-gen-ns + fname) + good-contents (slurp "test/good-test-symbol.clj") + contents (slurp fname)] + (is (= good-contents contents)))) + + (testing "ndarray" + (let [fname "test/test-ndarray.clj" + _ (gen/write-to-file [(first gen/all-ndarray-functions)] + gen/ndarray-gen-ns + fname) + good-contents (slurp "test/good-test-ndarray.clj") + contents (slurp fname)] + (is (= good-contents contents))))) diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj new file mode 100644 index 000000000000..8cdfce76fcb5 --- /dev/null +++ b/contrib/clojure-package/test/good-test-ndarray.clj @@ -0,0 +1,36 @@ +(ns org.apache.clojure-mxnet.ndarray + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle]) + (:import (org.apache.mxnet NDArray Shape))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + div + ([ndarray num-or-ndarray] + (util/coerce-return + (.$div + ndarray + (util/coerce-param + num-or-ndarray + #{"float" "org.apache.mxnet.NDArray"}))))) + diff --git a/contrib/clojure-package/test/good-test-symbol.clj b/contrib/clojure-package/test/good-test-symbol.clj new file mode 100644 index 000000000000..0f7479ad456f --- /dev/null +++ b/contrib/clojure-package/test/good-test-symbol.clj @@ -0,0 +1,38 @@ +(ns org.apache.clojure-mxnet.symbol + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle]) + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Symbol))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + div + ([sym sym-or-object] + (util/coerce-return + (.$div + sym + (util/nil-or-coerce-param + sym-or-object + #{"org.apache.mxnet.Symbol" "java.lang.Object"}))))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/callback_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/callback_test.clj new file mode 100644 index 000000000000..5957d209d2eb --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/callback_test.clj @@ -0,0 +1,17 @@ +(ns org.apache.clojure-mxnet.callback-test + (:require [org.apache.clojure-mxnet.callback :as callback] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(deftest test-speedometer + (let [speedometer (callback/speedometer 1) + metric (eval-metric/accuracy)] + (eval-metric/update metric [(ndarray/ones [2])] [(ndarray/ones [2 3])]) + ;;; only side effects of logging + (callback/invoke speedometer 0 1 metric) + (callback/invoke speedometer 0 2 metric) + (callback/invoke speedometer 0 3 metric) + (callback/invoke speedometer 0 10 metric) + (callback/invoke speedometer 0 50 metric) + (callback/invoke speedometer 0 100 metric))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj new file mode 100644 index 000000000000..372672b462bf --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj @@ -0,0 +1,77 @@ +(ns org.apache.clojure-mxnet.conv-test + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym] + [clojure.reflect :as r])) + +(def data-dir "data/") +(def batch-size 100) +(def num-epoch 1) + +(when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) + (sh "./scripts/get_mnist_data.sh")) + +;;; Load the MNIST datasets +(def train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :data-shape [1 28 28] + :label-shape [1 1 10] + :batch-size batch-size + :shuffle true + :flat false + :silent false + :seed 10})) + +(def test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :data-shape [1 28 28] + :batch-size batch-size + :flat false + :silent false})) +(defn get-symbol [] + (as-> (sym/variable "data") data + + (sym/convolution "conv1" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + (sym/batch-norm "bn1" {:data data}) + (sym/activation "relu1" {:data data :act-type "relu"}) + (sym/pooling "mp1" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + + + (sym/convolution "conv2" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + (sym/batch-norm "bn2" {:data data}) + (sym/activation "relu2" {:data data :act-type "relu"}) + (sym/pooling "mp2" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + + (sym/flatten "fl" {:data data}) + (sym/fully-connected "fc2" {:data data :num-hidden 10}) + (sym/softmax-output "softmax" {:data data}))) + + + +(deftest test-conv [] + (let [mod (m/module (get-symbol) )] + ;;; note only one function for training + (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch + :fit-params (m/fit-params {:optimizer (optimizer/sgd {:learning-rate 0.1 + :momentum 0.9 + :wd 0.0001})})}) + + ;;high level predict (just a dummy call but it returns a vector of results + (m/predict mod {:eval-data test-data}) + + ;;;high level score (returs the eval values) + (let [score (m/score mod {:eval-data test-data :eval-metric (eval-metric/accuracy)})] + (println "Score" score) + (is (< 0.92 (last score)))))) + +(comment + + (require '[clojure.reflect :as r]) + (r/reflect train-data) + ) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj new file mode 100644 index 000000000000..a040d950acd1 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj @@ -0,0 +1,42 @@ +(ns org.apache.clojure-mxnet.eval-metric-test + (:require [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(defn test-eval-metric [test-metric metric-name labels preds metric-val] + (println "Testing eval metric" metric-name) + (let [metric test-metric] + (eval-metric/update metric labels preds) + (is (= [metric-name metric-val] (eval-metric/get metric))) + + (testing "get does not reset the metric" + (is (= [metric-name metric-val] (eval-metric/get metric)))) + + (testing "resetting the metric" + (eval-metric/reset metric) + (is (= [metric-name "NaN"] (map str (eval-metric/get metric))))) + + (testing "get-and-reset gets the metric and then resets it" + (eval-metric/update metric labels preds) + (is (= [metric-name metric-val] (eval-metric/get-and-reset metric))) + (is (= [metric-name "NaN"] (map str (eval-metric/get metric))))))) + +(deftest test-metrics + (doseq [[metric-fn metric-name labels preds metric-val] + [[(eval-metric/accuracy) "accuracy" [(ndarray/zeros [2])] [(ndarray/zeros [2 3])] 1.0] + [(eval-metric/top-k-accuracy 2) "top_k_accuracy" [(ndarray/zeros [2])] [(ndarray/zeros [2 3])] 1.0] + [(eval-metric/f1) "f1" [(ndarray/zeros [2])] [(ndarray/zeros [2 3])] 0.0] + [(eval-metric/perplexity) "Perplexity" [(ndarray/ones [2])] [(ndarray/ones [2 3])] 1.0] + [(eval-metric/mae) "mae" [(ndarray/ones [2])] [(ndarray/ones [2])] 0.0] + [(eval-metric/mse) "mse" [(ndarray/ones [2])] [(ndarray/ones [2])] 0.0] + [(eval-metric/rmse) "rmse" [(ndarray/ones [2])] [(ndarray/ones [2])] 0.0]]] + (test-eval-metric metric-fn metric-name labels preds metric-val))) + +(deftest test-custom-metric + (let [metric (eval-metric/custom-metric (fn [label pred] + (float + (- (apply + (ndarray/->vec label)) + (apply + (ndarray/->vec pred))))) + "my-metric")] + (eval-metric/update metric [(ndarray/ones [2])] [(ndarray/ones [2])]) + (is (= ["my-metric" 0.0] (eval-metric/get metric))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj new file mode 100644 index 000000000000..837f2f5a2653 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/executor_test.clj @@ -0,0 +1,60 @@ +(ns org.apache.clojure-mxnet.executor-test + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.random :as random] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.test-util :as test-util] + [clojure.test :refer :all])) + +(deftest test-bind + (let [shape [100 30] + lhs (sym/variable "lhs") + rhs (sym/variable "rhs") + ret (sym/+ lhs rhs)] + (is (= ["lhs" "rhs"] (sym/list-arguments ret))) + + (let [lhs-arr (random/uniform -10 10 shape) + rhs-arr (random/uniform -10 10 shape) + lhs-grad (ndarray/empty shape) + rhs-grad (ndarray/empty shape) + exec (sym/bind ret (context/default-context) [lhs-arr rhs-arr] [lhs-grad rhs-grad]) + exec2 (sym/bind ret (context/default-context) [lhs-arr rhs-arr]) + exec3 (sym/bind ret (context/default-context) {"rhs" rhs-arr "lhs" lhs-arr} {"lhs" lhs-grad "rhs" rhs-grad})] + (executor/forward exec) + (executor/forward exec2) + (executor/forward exec3) + (is (test-util/approx= 1e-6 (-> (ndarray/+ lhs-arr rhs-arr) ndarray/->vec) (-> (executor/outputs exec) first ndarray/->vec))) + (is (test-util/approx= 1e-6 (-> (ndarray/+ lhs-arr rhs-arr) ndarray/->vec) (-> (executor/outputs exec2) first ndarray/->vec))) + (is (test-util/approx= 1e-6 (-> (ndarray/+ lhs-arr rhs-arr) ndarray/->vec) (-> (executor/outputs exec3) first ndarray/->vec))) + + ;; test gradient + (let [out-grad (ndarray/ones shape) + lhs-grad2 out-grad + rhs-grad2 out-grad] + (executor/backward exec out-grad) + (is (test-util/approx= 1e-6 (ndarray/->vec lhs-grad) (ndarray/->vec lhs-grad2))) + (is (test-util/approx= 1e-6 (ndarray/->vec rhs-grad) (ndarray/->vec rhs-grad2))))))) + + +(deftest test-reshape + (let [x (sym/variable "x") + y (sym/fully-connected {:data x :num-hidden 4}) + exec (sym/simple-bind y (context/default-context) {"x" [5 4]}) + _ (executor/set-arg-arrays exec [1 1 0]) + new-exec (executor/reshape exec {"x" [3 4]})] + (executor/forward new-exec) + ;; test sub exec forward + (is (every? #(= 4.0 %) (->> (executor/outputs new-exec) + (map ndarray/->vec) + first))) + ;; test shared memory + (is (= [4.0 4.0 4.0]) (->> (executor/outputs exec) + (map ndarray/->vec) + first + (take 3))) + ;; test base exec forward + (executor/forward exec) + (is (every? #(= 4.0 %) (->> (executor/outputs exec) + (map ndarray/->vec) + first))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj new file mode 100644 index 000000000000..22a45da8e3ac --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj @@ -0,0 +1,150 @@ +(ns org.apache.clojure-mxnet.io-test + (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.shape :as mx-shape] + [clojure.test :refer :all])) + +(deftest test-mnsit-iter-and-mnist-pack + (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte")) + (sh "scripts/get_mnist_data.sh")) + params {:image "data/train-images-idx3-ubyte" + :label "data/train-labels-idx1-ubyte" + :data-shape [784] + :batch-size 100 + :shuffle 1 + :flat 1 + :silent 0 + :seed 10} + mnist-pack (mx-io/mnist-pack params)] + (is (= 600 (count (mx-io/batches mnist-pack)))) + + (let [mnist-iter (mx-io/iterator mnist-pack) + provide-data (mx-io/provide-data mnist-iter) + provide-label (mx-io/provide-label mnist-iter)] + (is (= [100 784] (-> provide-data first :shape))) + (is (= [100] (-> provide-label first :shape))) + (is (= 600 (mx-io/reduce-batches mnist-iter (fn [result batch] (inc result))))) + ;; test reset + (let [_ (mx-io/reset mnist-iter) + _ (mx-io/next mnist-iter) + label0 (-> (mx-io/iter-label mnist-iter) first (ndarray/->vec)) + data0 (-> (mx-io/iter-data mnist-iter) first (ndarray/->vec)) + _ (mx-io/next mnist-iter) + _ (mx-io/next mnist-iter) + _ (mx-io/next mnist-iter) + _ (mx-io/reset mnist-iter) + _ (mx-io/next mnist-iter) + label1 (-> (mx-io/iter-label mnist-iter) first (ndarray/->vec)) + data1 (-> (mx-io/iter-data mnist-iter) first (ndarray/->vec))] + (is (= label1 label0)) + (is (= data1 data0)))))) + + +(deftest test-image-record-iter + (let [_ (when-not (.exists (io/file "data/cifar/train.rec")) + (sh "scripts/get_cifar_data.sh")) + params {:path-imgrec "data/cifar/train.rec" + :label "data/cifar/cifar10_mean.bin" + :rand-crop false + :and-mirror false + :shuffle false + :data-shape [3 28 28] + :batch-size 100 + :preprocess-threads 4 + :prefetch-buffer 1} + img-rec-iter (mx-io/image-record-iter params) + nbatch 500] + (is (= [100 3 28 28] (-> (mx-io/provide-data img-rec-iter) first :shape))) + (is (= [100] (-> (mx-io/provide-label img-rec-iter) first :shape))) + (is (= nbatch (mx-io/reduce-batches img-rec-iter (fn [result batch] (inc result))))))) + +(deftest test-resize-iter + (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte")) + (sh "scripts/get_mnist_data.sh")) + params {:image "data/train-images-idx3-ubyte" + :label "data/train-labels-idx1-ubyte" + :data-shape [784] + :batch-size 100 + :shuffle 1 + :flat 1 + :silent 0 + :seed 10} + mnist-iter (mx-io/mnist-iter params) + nbatch 400 + resize-iter (mx-io/resize-iter mnist-iter nbatch false)] + (is (= nbatch (mx-io/reduce-batches resize-iter (fn [result batch] (inc result))))) + (mx-io/reset resize-iter) + (is (= nbatch (mx-io/reduce-batches resize-iter (fn [result batch] (inc result))))))) + +(deftest test-prefetching-iter + (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte")) + (sh "scripts/get_mnist_data.sh")) + params {:image "data/train-images-idx3-ubyte" + :label "data/train-labels-idx1-ubyte" + :data-shape [784] + :batch-size 100 + :shuffle 1 + :flat 1 + :silent 0 + :seed 10} + mnist-iter1 (mx-io/mnist-iter params) + mnist-iter2 (mx-io/mnist-iter params) + nbatch 600 + prefetch-iter (mx-io/prefetching-iter [mnist-iter1 mnist-iter2] + [{"data" "data1"} {"data" "data2"}] + [{"label" "label1"} {"label" "label2"}])] + (is (= nbatch (mx-io/reduce-batches prefetch-iter (fn [result batch] (inc result))))) + (let [provide-data (mx-io/provide-data prefetch-iter) + provide-label (mx-io/provide-label prefetch-iter)] + (is (= #{[100 784]} (into #{} (map :shape provide-data)))) + (is (= #{[100]} (into #{} (map :shape provide-label)))) + (mx-io/dispose prefetch-iter)))) + + +(deftest test-ndarray-iter + (let [shape0 [1000 2 2] + data [(ndarray/ones shape0) (ndarray/zeros shape0)] + shape1 [1000 1] + label [(ndarray/ones shape1)] + batch-data0 (ndarray/ones [128 2 2]) + batch-data1 (ndarray/zeros [128 2 2]) + batch-label (ndarray/ones [128 1])] + + ;; test pad + (let [data-iter0 (mx-io/ndarray-iter data {:label label + :data-batch-size 128 + :shuffle false + :last-batch-handle "pad"}) + nbatch0 8] + (is (= nbatch0 (count (mx-io/for-batches data-iter0 (fn [batch] 1))))) + (is (every? true? (mx-io/for-batches data-iter0 + (fn [batch] + (= batch-data0 + (first (mx-io/batch-data batch))))))) + (is (every? true? (mx-io/for-batches data-iter0 + (fn [batch] + (= batch-data1 + (second (mx-io/batch-data batch))))))) + (is (every? true? (mx-io/for-batches data-iter0 + (fn [batch] + (= batch-label + (first (mx-io/batch-label batch)))))))) + + ;; test discard + (let [data-iter1 (mx-io/ndarray-iter data {:label label + :data-batch-size 128 + :shuffle false + :last-batch-handle "discard"}) + nbatch1 7] + (is (= nbatch1 (mx-io/reduce-batches data-iter1 (fn [result batch] (inc result)))))) + + ;; test empty label for prediction + (let [data-iter2 (mx-io/ndarray-iter data {:data-batch-size 128 + :shuffle false + :last-batch-handle "discard"}) + nbatch2 7] + (is (= nbatch2 (mx-io/reduce-batches data-iter2 (fn [result batch] (inc result))))) + (is (= [] (mx-io/iter-init-label data-iter2)))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/kvstore_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/kvstore_test.clj new file mode 100644 index 000000000000..7be8751c1fdf --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/kvstore_test.clj @@ -0,0 +1,64 @@ +(ns org.apache.clojure-mxnet.kvstore-test + (:require [org.apache.clojure-mxnet.kvstore :as kvstore] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.context :as context])) + +(deftest test-init-and-pull + (let [kv (kvstore/create) + shape [2 1] + out (ndarray/zeros shape)] + (-> kv + (kvstore/init "3" (ndarray/ones shape)) + (kvstore/pull "3" out)) + (is (= [1.0 1.0] (ndarray/->vec out))))) + +(deftest test-push-and-pull + (let [kv (kvstore/create) + shape [2 1] + out (ndarray/zeros shape)] + (-> kv + (kvstore/init "3" (ndarray/ones shape)) + (kvstore/push "3" (ndarray/* (ndarray/ones shape) 4)) + (kvstore/pull "3" out)) + (is (= [4.0 4.0] (ndarray/->vec out))))) + +(deftest test-aggregate + (let [shape [4 4] + ks ["b" "c" "d"] + kv (kvstore/create) + num-devs 4 + devs (mapv (fn [_] (context/cpu)) (range num-devs)) + vals (mapv #(ndarray/ones shape {:ctx %}) devs)] + (-> kv + (kvstore/init "a" (ndarray/zeros shape)) + (kvstore/init ks [(ndarray/zeros shape) (ndarray/zeros shape) (ndarray/zeros shape)]) + (kvstore/push "a" vals) + (kvstore/pull "a" vals)) + (is (= 0.0 (->> vals + (mapv ndarray/->vec) + flatten + (map #(- % num-devs)) + (apply +)))) + (let [result (for [k ks] + (let [tmp-vals (mapv #(ndarray/* (ndarray/ones shape {:ctx %}) 2.0) devs)] + (-> kv + (kvstore/push k tmp-vals) + (kvstore/pull k tmp-vals)) + (map ndarray/->vec tmp-vals)))] + (is (= 0.0 (->> result + (flatten) + (map #(- % (* num-devs 2))) + (apply +))))))) + +(deftest test-type + (is (= "local" (-> (kvstore/create "local") + (kvstore/type))))) + +(deftest test-get-numworkers + (is (= 1 (-> (kvstore/create "local") + (kvstore/num-workers))))) + +(deftest test-get-rank + (is (= 0 (-> (kvstore/create "local") + (kvstore/rank))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj new file mode 100644 index 000000000000..89ab27d108f8 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj @@ -0,0 +1,311 @@ +(ns org.apache.clojure-mxnet.module-test + (:require [clojure.java.io :as io] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.monitor :as monitor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.util :as util] + [clojure.spec.alpha :as s] + [clojure.test :refer :all] + [clojure.reflect :as r] + [clojure.string :as string])) + + +(deftest test-model-dtype + (let [dtype dtype/FLOAT32 + dshape [3 8 7] + s (sym/variable "data") + s (sym/activation "act" {"__layout__" "TNC"} {:data s :act_type "relu"}) + + mod (m/module s ["data"] nil [(context/cpu 0) (context/cpu 1)])] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape dshape :dtype dtype :layout "TNC"}]}) + (m/init-params) + (m/forward {:data [(ndarray/ones dshape {:dtype dtype})]}) + (m/backward[(ndarray/ones dshape {:dtype dtype})])) + (let [outputs (-> mod (m/outputs) flatten)] + (is (every? #(= dtype/FLOAT32 (ndarray/dtype %)) outputs))))) + +(deftest test-module-input-grads + (let [a (sym/variable "a" {:kwargs {"__layout__" "NC"}}) + b (sym/variable "b" {:kwargs {"__layout__" "NC"}}) + c (sym/variable "c" {:kwargs {"__layout__" "NC"}} ) + c (sym/+ a (sym/+ (sym/* b 2) (sym/* c 3))) + mod (m/module c ["b" "c" "a"] nil [(context/cpu 0) (context/cpu 1)])] + (-> mod + (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout "NT"} + {:name "c" :shape [5 5] :layout "NT"} + {:name "a" :shape [5 5] :layout "NT"}] + :inputs-need-grad true}) + (m/init-params) + (m/forward {:data [(ndarray/ones [5 5]) + (ndarray/ones [5 5]) + (ndarray/ones [5 5])] + :label nil + :index nil + :pad 0}) + (m/backward [(ndarray/ones [5 5])])) + (let [[a-grad b-grad c-grad] (m/input-grads-merged mod)] + (is (every? #(= 1.0 %) (ndarray/->vec a-grad))) + (is (every? #(= 2.0 %) (ndarray/->vec b-grad))) + (is (every? #(= 3.0 %) (ndarray/->vec c-grad)))))) + +(deftest test-module-layout + (let [s (sym/variable "data") + s (sym/activation "act "{"__layout__" "TNC"} {:data s :act_type "relu"}) + dshape [3 8 7] + mod (m/module s ["data"] nil [(context/cpu 0) (context/cpu 1)])] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape dshape :dtype dtype/FLOAT32 :layout "TNC"}]}) + (m/init-params) + (m/forward {:data [(ndarray/ones dshape)] + :label nil + :index nil + :pad 0}) + (m/backward[(ndarray/ones dshape)])) + (let [outputs-merged (m/outputs-merged mod) + outputs (m/outputs mod) + hd-shape [3 4 7]] + (is (= dshape (-> outputs-merged first (ndarray/shape) (ndarray/->vec)))) + (is (every? #(= hd-shape (-> % ndarray/shape ndarray/->vec)) (flatten outputs)))))) + +(deftest test-module-save-load-single-device + (let [s (sym/variable "data") + s (sym/fully-connected {:data s :num-hidden 100}) + ;; single device + mod (m/module s {:data-names ["data"] :label-names nil})] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})}) + (m/update) + (m/save-checkpoint {:prefix "test" :epoch 0 :save-opt-states true})) + + (let [mod2 (m/load-checkpoint {:prefix "test" :epoch 0 :load-optimizer-states true})] + (-> mod2 + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})})) + (is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json) )) + (is (= (-> mod m/params first) (-> mod2 m/params first)))))) + +(deftest test-module-save-load-multi-device + (let [s (sym/variable "data") + s (sym/fully-connected {:data s :num-hidden 100}) + ;; multi device + mod (m/module s {:data-names ["data"] :label-names nil + :contexts [(context/cpu 0) (context/cpu 1)]})] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})}) + (m/update) + (m/save-checkpoint {:prefix "test" :epoch 0 :save-opt-states true})) + + (let [mod2 (m/load-checkpoint {:prefix "test" :epoch 0 :load-optimizer-states true})] + (-> mod2 + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})})) + (is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json) )) + (is (= (-> mod m/params first) (-> mod2 m/params first)))))) + +(deftest test-module-reshape + (let [s (sym/variable "data") + s (sym/fully-connected "fc" {:data s :num-hidden 20}) + dshape [7 20] + mod (m/module s ["data"] nil [(context/cpu 0) (context/cpu 1)])] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape dshape :layout "NT"}]}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 1.0})}) + (m/forward {:data [(ndarray/ones dshape)] :label nil :index nil :pad 0}) + (m/backward [(ndarray/ones dshape)]) + (m/update)) + (is (= dshape (-> (m/outputs-merged mod) first ndarray/shape mx-shape/->vec))) + (is (every? #(= -1.0 %) (-> (m/params mod) (first) (get "fc_bias") (ndarray/->vec)))) + + (let [dshape [14 20]] + (-> mod + (m/reshape [{:name "data" :shape dshape :layout "NT"}]) + (m/forward {:data [(ndarray/ones dshape)] :label nil :index nil :pad 0}) + (m/backward [(ndarray/ones dshape)]) + (m/update)) + (is (= dshape (-> (m/outputs-merged mod) first ndarray/shape mx-shape/->vec))) + (is (every? #(< 1e-3 (- 3 %)) (-> mod m/params first (get "fc_bias") (ndarray/->vec))))))) + +(deftest test-set-params + (let [data (ndarray/array [0.05 0.1] [1 1 1 2]) + label (ndarray/array [0.01 0.99] [1 1 1 2]) + train-data (mx-io/ndarray-iter [data] {:label [label] :label-name "softmax_label"}) + x (as-> (sym/variable "data") v + (sym/fully-connected "fc_0" {:data v :num-hidden 2}) + (sym/activation "act_0" {:data v :act-type "sigmoid"}) + (sym/fully-connected "fc_1" {:data v :num-hidden 2}) + (sym/activation "act_1" {:data v :act-type "sigmoid"}) + (sym/linear-regression-output "softmax" {:data v :grad-scale 2})) + + mod (m/module x)] + (m/bind mod {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)} ) + + (let [arg-params-correct {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + "fc_0_bias" (ndarray/array [0.35 0.35] [2]) + "fc_1_weight" (ndarray/array [0.4 0.45 05 0.55] [2 2]) + "fc_1_bias" (ndarray/array [0.6 0.6] [2])} + arg-params-missing {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + "fc_0_bias" (ndarray/array [0.35 0.35] [2]) + "fc_1_weight" (ndarray/array [0.4 0.45 05 0.55] [2 2])} + arg-params-extra {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + "fc_0_bias" (ndarray/array [0.35 0.35] [2]) + "fc_1_weight" (ndarray/array [0.4 0.45 05 0.55] [2 2]) + "fc_1_bias" (ndarray/array [0.6 0.6] [2]) + "fc_2_weight" (ndarray/array [0.6 0.6] [2])}] + (m/set-params mod {:arg-params arg-params-correct :force-init true}) + (m/set-params mod {:arg-params arg-params-missing :allow-missing true}) + (m/set-params mod {:arg-params arg-params-extra :allow-extra true})))) + +(deftest test-monitor + (let [data (ndarray/array [0.05 0.1] [1 1 1 2]) + label (ndarray/array [0.01 0.99] [1 1 1 2]) + train-data (mx-io/ndarray-iter [data] {:label [label] :label-name "softmax_label"}) + x (as-> (sym/variable "data") v + (sym/fully-connected "fc_0" {:data v :num-hidden 2}) + (sym/activation "act_0" {:data v :act-type "sigmoid"}) + (sym/fully-connected "fc_1" {:data v :num-hidden 2}) + (sym/activation "act_1" {:data v :act-type "sigmoid"}) + (sym/linear-regression-output "softmax" {:data v :grad-scale 2})) + ;; create monitor + mon (monitor/monitor 1 (fn [x] + (ndarray/div (ndarray/sum (ndarray/abs x)) + (mx-shape/product (ndarray/shape x))))) + mod (m/module x {:contexts [(context/cpu 0)]}) + arg-params {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) + "fc_0_bias" (ndarray/array [0.35 0.35] [2]) + "fc_1_weight" (ndarray/array [0.4 0.45 05 0.55] [2 2]) + "fc_1_bias" (ndarray/array [0.6 0.6] [2])} + data-batch (mx-io/next train-data)] + (-> mod + (m/bind {:data-shapes [{:name "data", :shape [1 1 1 2]}] + :label-shapes [{:name "softmax_label", :shape [1 1 1 2]}]}) + (m/install-monitor mon) + (m/init-params {:arg-params arg-params})) + (monitor/tic mon) + (m/forward-backward mod data-batch) + (let [result (monitor/toc mon) + freq (->> result + (map (fn [v] (as-> (second v) ? + (clojure.string/split ? #"_") + (take 2 ?) + (clojure.string/join "_" ?)))) + (frequencies)) + expected-freq {"act_0" 2 "act_1" 2 "data" 1 "fc_0" 6 "fc_1" 6}] + (is (= expected-freq (select-keys freq (keys expected-freq))))))) + + +(deftest test-forward-reshape + (let [num-class 10 + data1 (sym/variable "data1") + data2 (sym/variable "data2") + conv1 (sym/convolution {:data data1 :kernel [2 2] :num-filter 2 :stride [2 2]}) + conv2 (sym/convolution {:data data2 :kernel [3 3] :num-filter 3 :stride [1 1]}) + pooling1 (sym/pooling {:data conv1 :kernel [2 2] :pool-type "avg" :stride [1 1]}) + pooling2 (sym/pooling {:data conv2 :kernel [2 2] :pool-type "max" :stride [1 1]}) + flatten1 (sym/flatten {:data pooling1}) + flatten2 (sym/flatten {:data pooling2}) + sum (sym/+ (sym/sum {:data flatten1 :axis 1}) + (sym/sum {:data flatten2 :axis 1})) + fc (sym/fully-connected {:data sum :num-hidden num-class}) + my-sym (sym/softmax-output "softmax" {:data fc}) + + d-shape1 [10 3 64 64] + d-shape2 [10 3 32 32] + l-shape [10] + mod (m/module my-sym {:data-names ["data1" "data2"]}) + data-batch {:data [(ndarray/random-uniform 0 9 (str (mx-shape/->shape d-shape1))) + (ndarray/random-uniform 5 15 (str (mx-shape/->shape d-shape2)))] + :label [(ndarray/ones l-shape)] + :index nil + :pad 0}] + + ;; train with the original shapes + (-> mod + (m/bind {:data-shapes [{:name "data1" :shape d-shape1} + {:name "data2" :shape d-shape2}] + :label-shapes [{:name "softmax_label" :shape l-shape :layout "N"}]}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1})}) + (m/forward data-batch)) + (is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec))) + (-> mod + (m/backward) + (m/update)) + + (let [d-shape1 [3 3 64 64] + d-shape2 [3 3 32 32] + l-shape [3] + data-batch-2 {:data [(ndarray/random-uniform 0 9 (str (mx-shape/->shape d-shape1))) + (ndarray/random-uniform 5 15 (str (mx-shape/->shape d-shape2)))] + :label [(ndarray/ones l-shape)] + :index nil + :pad 0}] + (-> mod + (m/forward data-batch)) + (is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec))) + (-> mod + (m/backward) + (m/update))) + + (let [d-shape1 [20 3 64 64] + d-shape2 [20 3 32 32] + l-shape [20] + data-batch-2 {:data [(ndarray/random-uniform 3 5 (str (mx-shape/->shape d-shape1))) + (ndarray/random-uniform 10 25 (str (mx-shape/->shape d-shape2)))] + :label [(ndarray/ones l-shape)] + :index nil + :pad 0}] + (-> mod + (m/forward data-batch)) + (is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec))) + (-> mod + (m/backward) + (m/update))) + + ;; train with both different batch sizes and data shapes + (let [d-shape1 [20 3 120 120] + d-shape2 [20 3 32 64] + l-shape [20] + data-batch {:data [(ndarray/random-uniform 0 9 (str (mx-shape/->shape d-shape1))) + (ndarray/random-uniform 15 25 (str (mx-shape/->shape d-shape2)))] + :label [(ndarray/ones l-shape)] + :index nil + :pad 0}] + (-> mod + (m/forward data-batch)) + (is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec))) + (-> mod + (m/backward) + (m/update))) + (let [d-shape1 [5 3 28 40] + d-shape2 [5 3 24 16] + l-shape [5] + data-batch {:data [(ndarray/random-uniform 0 9 (str (mx-shape/->shape d-shape1))) + (ndarray/random-uniform 15 25 (str (mx-shape/->shape d-shape2)))] + :label [(ndarray/ones l-shape)] + :index nil + :pad 0}] + (-> mod + (m/forward data-batch)) + (is (= [(first l-shape) num-class]) (-> (m/outputs-merged mod) first (ndarray/shape) (mx-shape/->vec))) + (-> mod + (m/backward) + (m/update))))) + + +(comment + + (m/data-shapes x) + ) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj new file mode 100644 index 000000000000..99dfb63d7ac7 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj @@ -0,0 +1,464 @@ +(ns org.apache.clojure-mxnet.ndarray-test + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as ctx] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.ndarray :as ndarray :refer [->vec zeros ones += -= *= full shape]] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.test-util :as test-util] + [clojure.test :refer :all])) + +(deftest test->vec + (is (= [0.0 0.0 0.0 0.0] (->vec (zeros [2 2]))))) + +(deftest test-to-array + (is (= [0.0 0.0 0.0 0.0]) (vec (ndarray/to-array (zeros [2 2]))))) + +(deftest test-to-scalar + (is (= 0.0 (ndarray/to-scalar (zeros [1])))) + (is (= 1.0 (ndarray/to-scalar (ones [1])))) + (is (thrown-with-msg? Exception #"The current array is not a scalar" + (ndarray/to-scalar (zeros [1 1]))))) + +(deftest test-size-and-shape + (let [m (zeros [4 1])] + (is (= (mx-shape/->shape [4 1]) (ndarray/shape m))) + (is (= 4 (ndarray/size m))))) + +(deftest test-dtype + (is (= base/MX_REAL_TYPE (ndarray/dtype (zeros [3 2]))))) + +(deftest test-set-scalar-value + (is (= [10.0 10.0] (-> (ndarray/empty [2 1]) + (ndarray/set 10) + (->vec))))) + +(deftest test-copy-from-vector + (is (= [1.0 2.0 3.0 4.0] (-> (ndarray/empty [4 1]) + (ndarray/set [1 2 3 4]) + (->vec))))) + +(deftest test-plus + (let [ndzeros (zeros [2 1]) + ndones (ndarray/+ ndzeros 1)] + (is (= [1.0 1.0] (->vec ndones))) + (is (= [2.0 2.0] (->vec (ndarray/+ ndones 1)))) + (is (= [1.0 1.0] (->vec ndones))) + ;;; += mutuates + (is (= [2.0 2.0]) (->vec (+= ndones 1))) + (is (= [2.0 2.0]) (->vec ndones)))) + +(deftest test-minus + (let [ndones (ones [2 1]) + ndzeros (ndarray/- ndones 1)] + (is (= [0.0 0.0] (->vec ndzeros))) + (is (= [-1.0 -1.0] (->vec (ndarray/- ndzeros 1)))) + (is (= [0.0 0.0] (->vec ndzeros))) + ;;; += mutuates + (is (= [-1.0 -1.0]) (->vec (-= ndzeros 1))) + (is (= [-1.0 -1.0]) (->vec ndzeros)))) + +(deftest test-multiplication + (let [ndones (ones [2 1]) + ndtwos (ndarray/* ndones 2)] + (is (= [2.0 2.0] (->vec ndtwos))) + (is (= [1.0 1.0] (->vec (ndarray/* ndones ndones)))) + (is (= [4.0 4.0] (->vec (ndarray/* ndtwos ndtwos)))) + ;; *= mutates + (is (= [4.0 4.0] (->vec (*= ndtwos ndtwos)))) + (is (= [4.0 4.0] (->vec ndtwos))))) + +(deftest test-division + (let [ndones (ones [2 1]) + ndzeros (ndarray/- ndones 1) + ndhalves (ndarray/div ndones 2)] + (is (= [0.5 0.5] (->vec ndhalves))) + (is (= [1.0 1.0] (->vec (ndarray/div ndhalves ndhalves)))) + (is (= [1.0 1.0] (->vec (ndarray/div ndones ndones)))) + (is (= [0.0 0.0] (->vec (ndarray/div ndzeros ndones)))) + ;; div= mutates + (is (= [1.0 1.0] (->vec (ndarray/div= ndhalves ndhalves)))) + (is (= [1.0 1.0] (->vec ndhalves))))) + +(deftest test-full + (let [nda (full [1 2] 3)] + (is (= (shape nda) (mx-shape/->shape [1 2]))) + (is (= [3.0 3.0] (->vec nda))))) + +(deftest test-clip + (let [nda (-> (ndarray/empty [3 2]) + (ndarray/set [1 2 3 4 5 6]))] + (is (= [2.0 2.0 3.0 4.0 5.0 5.0] (->vec (ndarray/clip nda 2 5)))))) + +(deftest test-sqrt + (let [nda (-> (ndarray/empty [4 1]) + (ndarray/set [0 1 4 9]))] + (is (= [0.0 1.0 2.0 3.0] (->vec (ndarray/sqrt nda)))))) + +(deftest test-rsqrt + (let [nda (ndarray/array [1.0 4.0] [2 1])] + (is (= [1.0 0.5] (->vec (ndarray/rsqrt nda)))))) + +(deftest test-norm + (let [nda (-> (ndarray/empty [3 1]) + (ndarray/set [1 2 3])) + normed (ndarray/norm nda)] + (is (= [1] (mx-shape/->vec (shape normed)))) + (is (test-util/approx= 1e-4 (Math/sqrt 14.0) (ndarray/to-scalar normed))))) + +(deftest test-one-hot-encode + (let [nda1 (ndarray/array [1 0 2] [3]) + nda2 (ndarray/empty [3 3]) + res (ndarray/onehot-encode nda1 nda2)] + (is (= [3 3] (mx-shape/->vec (shape res)))) + (is (= [0.0 1.0 0.0 + 1.0 0.0 0.0 + 0.0 0.0 1.0] (->vec res))))) + +(deftest test-dot + (let [nda1 (ndarray/array [1 2] [1 2]) + nda2 (ndarray/array [3 4] [2 1]) + res (ndarray/dot nda1 nda2)] + (is (= [1 1] (mx-shape/->vec (shape res)))) + (is (= [11.0] (->vec res))))) + + +(deftest test-arrange + (let [start 0 + stop 5 + step 0.5 + repeat 2] + (is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5 4.0 4.0 4.5 4.5] + (->vec (ndarray/arange start stop {:step step :repeat repeat})))))) + +(deftest test-power + (let [nda (ndarray/array [3 5] [2 1])] + + (let [nda-power-1 (ndarray/power 2 nda)] + (is (= [2 1] (-> nda-power-1 shape mx-shape/->vec))) + (is (= [8.0 32.0] (->vec nda-power-1)))) + + (let [nda-power-2 (ndarray/power nda 2)] + (is (= [2 1] (-> nda-power-2 shape mx-shape/->vec))) + (is (= [9.0 25.0] (->vec nda-power-2)))) + + (let [nda-power-3 (ndarray/power nda nda)] + (is (= [2 1] (-> nda-power-3 shape mx-shape/->vec))) + (is (= [27.0 3125.0] (->vec nda-power-3)))) + + (let [nda-power-4 (ndarray/** nda 2)] + (is (= [2 1] (-> nda-power-4 shape mx-shape/->vec))) + (is (= [9.0 25.0] (->vec nda-power-4)))) + + (let [nda-power-5 (ndarray/** nda nda)] + (is (= [2 1] (-> nda-power-5 shape mx-shape/->vec))) + (is (= [27.0 3125.0] (->vec nda-power-5)))) + + (let [_ (ndarray/**= nda 2)] + (is (= [2 1] (-> nda shape mx-shape/->vec))) + (is (= [9.0 25.0] (->vec nda)))) + + (let [_ (ndarray/set nda [3 5]) + _ (ndarray/**= nda nda)] + (is (= [2 1] (-> nda shape mx-shape/->vec))) + (is (= [27.0 3125.0] (->vec nda)))))) + + +(deftest test-equal + (let [nda1 (ndarray/array [1 2 3 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/equal nda1 nda2) shape mx-shape/->vec))) + (is (= [1.0 0.0 1.0 0.0] (->vec (ndarray/equal nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/equal nda1 3) shape mx-shape/->vec))) + (is (= [0.0 0.0 1.0 0.0] (->vec (ndarray/equal nda1 3)))))) + + +(deftest test-not-equal + (let [nda1 (ndarray/array [1 2 3 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/not-equal nda1 nda2) shape mx-shape/->vec))) + (is (= [0.0 1.0 0.0 1.0] (->vec (ndarray/not-equal nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/not-equal nda1 3) shape mx-shape/->vec))) + (is (= [1.0 1.0 0.0 1.0] (->vec (ndarray/not-equal nda1 3)))))) + +(deftest test-greater + (let [nda1 (ndarray/array [1 2 4 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/> nda1 nda2) shape mx-shape/->vec))) + (is (= [0.0 0.0 1.0 0.0] (->vec (ndarray/> nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/> nda1 2) shape mx-shape/->vec))) + (is (= [0.0 0.0 1.0 1.0] (->vec (ndarray/> nda1 2)))))) + + +(deftest test-greater-equal + (let [nda1 (ndarray/array [1 2 4 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/>= nda1 nda2) shape mx-shape/->vec))) + (is (= [1.0 0.0 1.0 0.0] (->vec (ndarray/>= nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/>= nda1 2) shape mx-shape/->vec))) + (is (= [0.0 1.0 1.0 1.0] (->vec (ndarray/>= nda1 2)))))) + +(deftest test-lesser + (let [nda1 (ndarray/array [1 2 4 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/< nda1 nda2) shape mx-shape/->vec))) + (is (= [0.0 1.0 0.0 1.0] (->vec (ndarray/< nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/< nda1 2) shape mx-shape/->vec))) + (is (= [1.0 0.0 0.0 0.0] (->vec (ndarray/< nda1 2)))))) + +(deftest test-lesser-equal + (let [nda1 (ndarray/array [1 2 4 5] [2 2]) + nda2 (ndarray/array [1 4 3 6] [2 2])] + + (is (= [2 2] (-> (ndarray/<= nda1 nda2) shape mx-shape/->vec))) + (is (= [1.0 1.0 0.0 1.0] (->vec (ndarray/<= nda1 nda2)))) + + (is (= [2 2] (-> (ndarray/< nda1 2) shape mx-shape/->vec))) + (is (= [1.0 1.0 0.0 0.0] (->vec (ndarray/<= nda1 2)))))) + + +(deftest test-choose-element-0index + (let [nda (ndarray/array [1 2 3 4 6 5] [2 3]) + indices (ndarray/array [0 1] [2]) + res (ndarray/choose-element-0index nda indices)] + (is (= [1.0 6.0] (->vec res))))) + +(deftest test-copy-to + (let [source (ndarray/array [1 2 3] [1 3]) + dest (ndarray/empty [1 3]) + _ (ndarray/copy-to source dest)] + (is (= [1 3] (-> dest shape mx-shape/->vec))) + (is (= [1.0 2.0 3.0] (->vec dest))))) + +(deftest test-abs + (let [nda (ndarray/array [-1 -2 3] [3 1])] + (is (= [1.0 2.0 3.0] (->vec (ndarray/abs nda)))))) + +(deftest test-sign + (let [nda (ndarray/array [-1 -2 3] [3 1])] + (is (= [-1.0 -1.0 1.0] (->vec (ndarray/sign nda)))))) + +(deftest test-round + (let [nda (ndarray/array [1.5 2.1 3.7] [3 1])] + (is (= [2.0 2.0 4.0] (->vec (ndarray/round nda)))))) + +(deftest test-ceil + (let [nda (ndarray/array [1.5 2.1 3.7] [3 1])] + (is (= [2.0 3.0 4.0] (->vec (ndarray/ceil nda)))))) + +(deftest test-floor + (let [nda (ndarray/array [1.5 2.1 3.7] [3 1])] + (is (= [1.0 2.0 3.0] (->vec (ndarray/floor nda)))))) + +(deftest test-square + (let [nda (ndarray/array [1 2 3] [3 1])] + (is (= [1.0 4.0 9.0] (->vec (ndarray/square nda)))))) + +(deftest test-exp + (let [nda (ones [1])] + (is (test-util/approx= 1e-3 2.71828 (ndarray/to-scalar (ndarray/exp nda)))))) + +(deftest test-log + (let [nda (-> (ndarray/empty [1]) + (ndarray/set 10))] + (is (test-util/approx= 1e-3 2.30258 (ndarray/to-scalar (ndarray/log nda)))))) + +(deftest test-cos + (let [nda (-> (ndarray/empty [1]) + (ndarray/set 12))] + (is (test-util/approx= 1e-3 0.8438539 (ndarray/to-scalar (ndarray/cos nda)))))) + +(deftest test-sin + (let [nda (-> (ndarray/empty [1]) + (ndarray/set 12))] + (is (test-util/approx= 1e-3 -0.536572918 (ndarray/to-scalar (ndarray/sin nda)))))) + +(deftest test-max + (let [nda (ndarray/array [1.5 2.1 3.7] [3 1])] + (is (test-util/approx= 1e-3 3.7 (ndarray/to-scalar (ndarray/max nda)))))) + +(deftest test-maximum + (let [nda1 (ndarray/array [1.5 2.1 3.7] [3 1]) + nda2 (ndarray/array [4 1 3.5] [3 1]) + res (ndarray/maximum nda1 nda2)] + (is (= [3 1] (-> res shape mx-shape/->vec))) + (is (test-util/approx= 1e-3 [4.0 2.1 3.7] (->vec res))))) + +(deftest test-min + (let [nda (ndarray/array [1.5 2.1 3.7] [3 1])] + (is (test-util/approx= 1e-3 1.5 (ndarray/to-scalar (ndarray/min nda)))))) + +(deftest test-minimum + (let [nda1 (ndarray/array [1.5 2.1 3.7] [3 1]) + nda2 (ndarray/array [4 1 3.5] [3 1]) + res (ndarray/minimum nda1 nda2)] + (is (= [3 1] (-> res shape mx-shape/->vec))) + (is (test-util/approx= 1e-3 [1.5 1.0 3.5] (->vec res))))) + +(deftest test-sum + (let [nda (ndarray/array [1 2 3 4] [2 2])] + (is (test-util/approx= 1e-3 10.0 (ndarray/to-scalar (ndarray/sum nda)))))) + +(deftest test-argmax-channel + (let [nda (ndarray/array [1 2 4 3] [2 2]) + argmax (ndarray/argmax-channel nda)] + (is (= [2] (-> argmax shape mx-shape/->vec))) + (is (= [1.0 0.0] (->vec argmax))))) + + +(deftest test-concatenate-axis-0 + (let [nda1 (ndarray/array [1 2 4 3 3 3] [2 3]) + nda2 (ndarray/array [8 7 6] [1 3]) + res (ndarray/concatenate [nda1 nda2])] + (is (= [3 3] (-> res shape mx-shape/->vec))) + (is (= [1.0 2.0 4.0 3.0 3.0 3.0 8.0 7.0 6.0] (->vec res))))) + +(deftest test-concatenate-axis-1 + (let [nda1 (ndarray/array [1 2 3 4] [2 2]) + nda2 (ndarray/array [5 6] [2 1]) + res (ndarray/concatenate [nda1 nda2] {:axis 1})] + (is (= [2 3] (-> res shape mx-shape/->vec))) + (is (= [1.0 2.0 5.0 3.0 4.0 6.0] (->vec res))))) + +(deftest test-transpose + (let [nda (ndarray/array [1 2 4 3 3 3] [2 3])] + (is (= [1.0 2.0 4.0 3.0 3.0 3.0] (->vec nda))) + (is (= [3 2] (-> (ndarray/t nda) shape mx-shape/->vec))) + (is (= [1.0 3.0 2.0 3.0 4.0 3.0] (->vec (ndarray/t nda)))))) + +(def file-seq-num (atom 0)) + +(deftest test-save-and-load-with-names + (let [filename (str (System/getProperty "java.io.tmpdir") "/ndarray" (swap! file-seq-num inc) ".bin") + nda (ndarray/array [1 2 3] [3 1]) + _ (ndarray/save filename {"local" nda}) + load-map (ndarray/load filename)] + (is (= ["local"] (keys load-map))) + (is (= 1 (count (vals load-map)))) + (is (= [3 1] (-> (get load-map "local") shape mx-shape/->vec))) + (is (= [1.0 2.0 3.0] (->vec (get load-map "local")))))) + +(deftest test-save-to-file-and-load-from-file + (let [filename (str (System/getProperty "java.io.tmpdir") "/ndarray" (swap! file-seq-num inc) ".bin") + nda (ndarray/array [1 2 3] [3 1]) + _ (ndarray/save-to-file filename nda) + load-nda (ndarray/load-from-file filename)] + (is (= [3 1] (-> load-nda shape mx-shape/->vec))) + (is (= [1.0 2.0 3.0] (->vec load-nda))))) + +(deftest test-get-context + (let [nda (ones [3 2]) + ctx (ndarray/context nda)] + (is (= "cpu" (ctx/device-type ctx))) + (is (= 0 (ctx/device-id ctx))))) + +(deftest test-equals + (let [nda1 (ndarray/array [1 2 3] [3 1]) + nda2 (ndarray/array [1 2 3] [3 1]) + nda3 (ndarray/array [1 2 3] [1 3]) + nda4 (ndarray/array [3 2 3] [3 1])] + (is (= nda1 nda2)) + (is (not= nda1 nda3)) + (is (not= nda1 nda4)))) + +(deftest test-slice + (let [nda (ndarray/array [1 2 3 4 5 6] [3 2])] + + (let [nda1 (ndarray/slice nda 1)] + (is (= [1 2] (-> nda1 shape mx-shape/->vec))) + (is (= [3.0 4.0] (->vec nda1)))) + + (let [nda2 (ndarray/slice nda 1 3)] + (is (= [2 2] (-> nda2 shape mx-shape/->vec))) + (is (= [3.0 4.0 5.0 6.0] (->vec nda2)))))) + +(deftest test-at + (let [nda (ndarray/array [1 2 3 4 5 6] [3 2]) + res (ndarray/at nda 1)] + (is (= [2] (-> res shape mx-shape/->vec))) + (is (= [3 4])))) + +(deftest test-reshape + (let [nda (ndarray/array [1 2 3 4 5 6] [3 2]) + nda1 (ndarray/reshape nda [2 3])] + (is (= [2 3] (-> nda1 shape mx-shape/->vec))) + (is (= [1.0 2.0 3.0 4.0 5.0 6.0] (->vec nda1))))) + +(deftest test-dispose-deps + (let [nda1 (ones [1 2]) + nda2 (ones [1 2]) + nda3 (ones [1 2]) + nda-with-deps (ndarray/+ nda3 (ndarray/+ nda1 nda2))] + (is (= 4 (ndarray/size (ndarray/dependencies nda-with-deps)))) + (is (contains? (-> (ndarray/dependencies nda-with-deps) keys set) (ndarray/handle nda1))) + (is (contains? (-> (ndarray/dependencies nda-with-deps) keys set) (ndarray/handle nda2))) + (is (contains? (-> (ndarray/dependencies nda-with-deps) keys set) (ndarray/handle nda3))) + (is (not (ndarray/is-disposed nda1))) + (is (not (ndarray/is-disposed nda2))) + (is (not (ndarray/is-disposed nda3))) + + (let [nda-no-deps (ndarray/dispose-deps nda-with-deps)] + (is (= 0 (ndarray/size (ndarray/dependencies nda-no-deps)))) + (is (ndarray/is-disposed nda1)) + (is (ndarray/is-disposed nda2)) + (is (ndarray/is-disposed nda3))))) + +(deftest test-dispose-deps-except + (let [nda1 (ones [1 2]) + nda2 (ones [1 2]) + nda3 (ones [1 2]) + nda1-2 (ndarray/+ nda1 nda2)] + + (let [res (-> (ndarray/+ nda1 nda2) + (ndarray/+ nda1-2) + (ndarray/+ nda3) + (ndarray/dispose-deps-except nda1-2))] + (is (= 3 (ndarray/size (ndarray/dependencies res)))) + (is (contains? (-> (ndarray/dependencies res) keys set) (ndarray/handle nda1))) + (is (contains? (-> (ndarray/dependencies res) keys set) (ndarray/handle nda2))) + (is (contains? (-> (ndarray/dependencies res) keys set) (ndarray/handle nda1-2))) + (is (not (ndarray/is-disposed nda1))) + (is (not (ndarray/is-disposed nda2))) + (is (ndarray/is-disposed nda3))))) + +(deftest test-serialize-deserialize + (let [nda (ndarray/* (ndarray/ones [1 2]) 3) + nda-bytes (ndarray/serialize nda) + nda-copy (ndarray/deserialize nda-bytes)] + (is (= nda nda-copy)))) + +(deftest test-dtype-int32 + (let [nda (ndarray/* (ones [1 2] {:dtype dtype/INT32}) 2)] + (is (= dtype/INT32 (ndarray/dtype nda))) + (is (= 8 (count (ndarray/->raw nda)))) + (is (= [2.0 2.0] (ndarray/->float-vec nda))) + (is (= [2 2] (ndarray/->int-vec nda))) + (is (= [2.0 2.0] (ndarray/->double-vec nda))) + (is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda))))) + +(deftest test-dtype-uint8 + (let [nda (ndarray/* (ones [1 2] {:dtype dtype/UINT8}) 2)] + (is (= dtype/UINT8 (ndarray/dtype nda))) + (is (= 2 (count (ndarray/->raw nda)))) + (is (= [2.0 2.0] (ndarray/->float-vec nda))) + (is (= [2 2] (ndarray/->int-vec nda))) + (is (= [2.0 2.0] (ndarray/->double-vec nda))) + (is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda))))) + +(deftest test-dtype-float64 + (let [nda (ndarray/* (ones [1 2] {:dtype dtype/FLOAT64}) 2)] + (is (= dtype/FLOAT64 (ndarray/dtype nda))) + (is (= 16 (count (ndarray/->raw nda)))) + (is (= [2.0 2.0] (ndarray/->float-vec nda))) + (is (= [2 2] (ndarray/->int-vec nda))) + (is (= [2.0 2.0] (ndarray/->double-vec nda))) + (is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj new file mode 100644 index 000000000000..da60d1a18fb5 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -0,0 +1,599 @@ +(ns org.apache.clojure-mxnet.operator-test + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.random :as random] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.test-util :as test-util] + [clojure.test :refer :all]) + (:import (org.apache.mxnet NDArray))) + +(defn approx= [tolerance x y] + (test-util/approx= tolerance + (if (instance? NDArray x) (ndarray/->vec x) x) + (if (instance? NDArray y) (ndarray/->vec y) y))) + +(deftest test-elementwise-sum + (let [n 4 + shape-vec [5 5 3] + inputs (mapv (fn [i] (sym/variable (str "arg" i))) (range n)) + out (sym/element-wise-sum "esum" inputs) + arr (into [] (repeatedly n #(random/uniform -10 10 shape-vec))) + arr-grad (into [] (repeatedly n #(ndarray/empty shape-vec))) + exec (sym/bind out (context/default-context) arr arr-grad) + forward-output (-> exec (executor/forward) (executor/outputs) first) + forward-output-expected (reduce sym/+ arr)] + (approx= 1e-4 forward-output-expected forward-output) + + ;; backward + (let [out-grad (random/uniform -10 10 shape-vec) + _ (executor/backward exec out-grad)] + (doseq [grad arr-grad] + (is (= out-grad grad)))))) + + +(deftest test-concat + (let [shape-vecs[[2 2] [3 2]] + x (sym/variable "x") + y (sym/variable "y") + out (sym/concat "conc" nil [x y] {:dim 0}) + arr (mapv #(ndarray/empty %) shape-vecs) + arr-np (mapv #(ndarray/copy %) arr) + arr-grad (map #(ndarray/empty %) shape-vecs) + arg-names (sym/list-arguments out) + grad-map (zipmap arg-names arr-grad) + args (sym/list-arguments out) + [arg-shapes out-shapes aux-shapes] (sym/infer-shape out (zipmap args shape-vecs)) + out-shape-vec (first out-shapes) + out-grad (ndarray/empty out-shape-vec) + exec1 (sym/bind out (context/default-context) arr grad-map) + out1 (-> (executor/forward exec1) + (executor/outputs) + (first)) + ret (ndarray/concatenate arr)] + (is (= out1 ret)) + + ;;backward + (ndarray/copy-to out1 out-grad) + (ndarray/+= out-grad 1) + (executor/backward exec1 out-grad) + (let [grads arr-grad + np-grads arr-np] + (is (= grads (mapv #(ndarray/+ % 1) np-grads)))))) + + +(defn check-regression [model forward-fn backward-fn] + (let [shape-vec [3 1] + arr-data (random/uniform -1 1 shape-vec) + arr-label (random/uniform -1 1 [(first shape-vec)]) + arr-grad (ndarray/empty shape-vec) + exec1 (sym/bind model (context/default-context) [arr-data arr-label] {:data arr-grad}) + out1 (-> exec1 (executor/forward) (executor/outputs) first) + np-out (map forward-fn + (ndarray/->vec arr-data))] + (is (= shape-vec (-> out1 ndarray/shape mx-shape/->vec))) + (is (approx= 1e-6 np-out out1)) + + ;;backward + (executor/backward exec1) + (let [npout-back (mapv backward-fn + np-out (ndarray/->vec arr-label))] + (is (approx= 1e-6 npout-back arr-grad))))) + + +(deftest test-regression + (check-regression (sym/logistic-regression-output {:data (sym/variable "data") :label (sym/variable "label")}) + (fn [x] (/ 1.0 (+ 1.0 (Math/exp (* -1.0 x))))) + (fn [x y] (- x y))) + (check-regression (sym/linear-regression-output {:data (sym/variable "data") :label (sym/variable "label")}) + (fn [x] x) + (fn [x y] (- x y)))) + +(deftest swap-axes + (let [data (sym/variable "data") + shape-vec [2 3 4] + arr-data (ndarray/ones shape-vec)] + + (-> (ndarray/slice arr-data 0) + (ndarray/set 1)) + + (-> (ndarray/slice arr-data 1) + (ndarray/set 2)) + + ;; [[[ 1., 1., 1., 1.], + ;; [ 1., 1., 1., 1.], + ;; [ 1., 1., 1., 1.]], + ;; + ;; [[ 2., 2., 2., 2.], + ;; [ 2., 2., 2., 2.], + ;; [ 2., 2., 2., 2.]]] + + (let [swap0 (sym/swap-axis {:data data :dim1 0 :dim2 2}) + swap (sym/swap-axis {:data swap0 :dim1 1 :dim2 2}) + exec (sym/bind swap (context/default-context) arr-data) + out (-> (executor/forward exec) + (executor/outputs) + first)] + ;; After swapaxes(swapaxes(arrData, 0, 2), 1, 2) + ;; out should be + ;; [[[ 1., 1., 1.], + ;; [ 2., 2., 2.]], + ;; + ;; [[ 1., 1., 1.], + ;; [ 2., 2., 2.]], + ;; + ;; [[ 1., 1., 1.], + ;; [ 2., 2., 2.]], + ;; + ;; [[ 1., 1., 1.], + ;; [ 2., 2., 2.]]] + (= [4 2 3] (mx-shape/->vec (ndarray/shape out))) + (doseq [i (range 4)] + (let [val (ndarray/->vec (ndarray/slice out i))] + (is (approx= 1e-6 [1 1 1 2 2 2] val))))))) + + +(defn check-symbolic-forward [test-sym location expected tolerance] + (let [arr-data (mapv #(ndarray/copy %) location) + arr-grad (mapv #(ndarray/empty (mx-shape/->vec (ndarray/shape %))) location) + exec (sym/bind test-sym (context/default-context) arr-data arr-grad) + outputs (-> exec + (executor/forward) + (executor/outputs))] + (is (every? true? (map + (fn [x y] + #_(println "expected " (ndarray/->vec x)) + #_(println "actual " (ndarray/->vec y)) + (approx= tolerance x y)) + expected + outputs))))) + +(defn check-symbolic-backward [test-sym location grad expected tolerance] + (let [arr-data (mapv #(ndarray/copy %) location) + arr-grad (mapv #(ndarray/empty (mx-shape/->vec (ndarray/shape %))) location) + out-grad (mapv #(ndarray/copy %) grad) + exec (sym/bind test-sym (context/default-context) arr-data arr-grad) + exec (-> exec + (executor/forward) + (executor/backward out-grad)) + grad-arrays (executor/grad-arrays exec)] + (is (every? true? (map + (fn [x y] + #_(println "expected " (ndarray/->vec x)) + #_(println "actual " (ndarray/->vec y)) + (approx= tolerance x y)) + expected + grad-arrays))))) + + +(deftest test-scalar-op + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5) + ;; (4x + 2)/2 + test (-> (sym/* data 4) + (sym/+ 2) + (sym/div 2)) + npout (-> (ndarray/* data-tmp 4) + (ndarray/+ 2) + (ndarray/div 2)) + ;; backward deriv is 2 + np-out-grad (ndarray/* (ndarray/ones shape-vec) 2)] + + (check-symbolic-forward test [data-tmp] [npout] 1e-5) + (check-symbolic-backward test [data-tmp] [(ndarray/ones shape-vec)] [np-out-grad] 1e-5))) + +(deftest ones + (let [ones (sym/ones [2 2]) + exec (sym/simple-bind ones (context/default-context))] + (is (approx= 1e-4 + [1 1 1 1] + (-> exec (executor/forward) (executor/outputs) (first)))))) + +(deftest zeros + (let [zeros (sym/zeros [2 2]) + exec (sym/simple-bind zeros (context/default-context))] + (is (approx= 1e-4 + [0 0 0 0] + (-> exec (executor/forward) (executor/outputs) (first)))))) + + +(deftest test-arange + (let [start 1 + stop 100 + step 2 + result (range start stop step) + x (sym/arange start stop {:step step}) + exec (sym/simple-bind x (context/default-context))] + (executor/forward exec) + (is (= 0 (count (executor/grad-arrays exec)))) + (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) + +(deftest test-scalar-pow + (let [data (sym/variable "data") + shape-vec [1 1] + data-tmp (ndarray/* (ndarray/ones shape-vec) 3) + data-tmp-powered (ndarray/* (ndarray/ones shape-vec) 9) + test (sym/** data 2)] + (check-symbolic-forward test [data-tmp] [data-tmp-powered] 1e-5) + (check-symbolic-backward test [data-tmp] [(ndarray/ones shape-vec)] [(ndarray/* data-tmp 2)] 1e-5))) + +(deftest test-symbol-pow + (let [shape-vec [1 1] + data (sym/variable "data") + data-tmp (ndarray/* (ndarray/ones shape-vec) 2) + exp (sym/variable "exp") + exp-tmp (ndarray/* (ndarray/ones shape-vec) 3) + test (sym/** data exp)] + (check-symbolic-forward test [data-tmp exp-tmp] [(ndarray/* (ndarray/ones shape-vec) 8)] 1e-5) + (let [data-deriv (ndarray/* (ndarray/* (ndarray/ones shape-vec) 4) exp-tmp) + exp-deriv (ndarray/* (ndarray/* (ndarray/ones shape-vec) 8) + (ndarray/* (ndarray/ones shape-vec) (Math/log 2)))] + (check-symbolic-backward test + [data-tmp exp-tmp] + [(ndarray/ones shape-vec)] + [data-deriv exp-deriv] 1e-5)))) + +(deftest test-pow-fn + (let [shape-vec [3 4] + exp (sym/variable "exp") + y (sym/** exp 2) + x (ndarray/* (ndarray/ones shape-vec) 3)] + (check-symbolic-forward y [x] [(ndarray/* (ndarray/ones shape-vec) 9)] 1e-5) + ;; deriv is 2x + (check-symbolic-backward y + [x] + [(ndarray/ones shape-vec)] + [(-> (ndarray/ones shape-vec) + (ndarray/* 6))] + 1e-5))) + +(defn check-scalar-operation + [operator data-vec num expected] + (let [data (sym/variable "datas") + shape-vec [2 2] + test (operator data num) + exec (sym/simple-bind test (context/default-context) {"datas" shape-vec}) + _ (executor/set-arg exec "datas" data-vec) + output (-> (executor/forward exec) (executor/outputs) first)] + (is (approx= 1e-5 expected output)) + (is (= [0 0 0 0]) (-> (executor/backward exec (ndarray/ones shape-vec)) + (executor/get-grad "datas") + (ndarray/->vec))))) + +(defn check-symbol-operation + [operator data-vec-1 data-vec-2 expected] + (let [data (sym/variable "datas") + data2 (sym/variable "datas2") + shape-vec [2 2] + test (operator data data2) + exec (sym/simple-bind test (context/default-context) {"datas" shape-vec "datas2" shape-vec}) + _ (executor/set-arg exec "datas" data-vec-1) + _ (executor/set-arg exec "datas2" data-vec-2) + output (-> (executor/forward exec) (executor/outputs) first)] + (is (approx= 1e-5 expected output)) + _ (executor/backward exec (ndarray/ones shape-vec)) + (is (= [0 0 0 0]) (-> (executor/get-grad exec "datas") (ndarray/->vec))) + (is (= [0 0 0 0]) (-> (executor/get-grad exec "datas2") (ndarray/->vec))))) + +(defn check-scalar-2-operation + [operator data-vec expected] + (let [data (sym/variable "datas") + shape-vec [2 2] + test (operator data 2) + exec (sym/simple-bind test (context/default-context) {"datas" shape-vec}) + _ (executor/set-arg exec "datas" data-vec) + output (-> (executor/forward exec) (executor/outputs) first)] + (is (approx= 1e-5 expected output)) + (is (= [0 0 0 0]) (-> (executor/backward exec (ndarray/ones shape-vec)) + (executor/get-grad "datas") + (ndarray/->vec))))) + +(deftest test-scalar-equal + (check-scalar-operation sym/equal [1 2 3 4] 2 [0 1 0 0])) + +(deftest test-symbol-equal + (check-symbol-operation sym/equal [1 2 3 4] [1 3 2 6] [1 0 0 0])) + +(deftest test-scalar-equal-2 + (check-scalar-2-operation sym/equal [1 2 3 4] [0 1 0 0])) + +(deftest test-scalar-not-equal + (check-scalar-operation sym/not-equal [1 2 3 4] 2 [1 0 1 1])) + +(deftest test-symbol-not-equal + (check-symbol-operation sym/not-equal [1 2 3 4] [1 3 2 6] [0 1 1 1])) + +(deftest test-scalar-not-equal-2 + (check-scalar-2-operation sym/not-equal [1 2 3 4] [1 0 1 1])) + +(deftest test-scalar-greater + (check-scalar-operation sym/> [1 2 3 4] 2 [0 0 1 1])) + +(deftest test-symbol-greater + (check-symbol-operation sym/> [1 2 3 4] [1 3 2 6] [0 0 1 0])) + +(deftest test-scalar-greater-equal + (check-scalar-operation sym/>= [1 2 3 4] 2 [0 1 1 1])) + +(deftest test-symbol-greater-equal + (check-symbol-operation sym/>= [1 2 3 4] [1 3 2 6] [1 0 1 0])) + +(deftest test-scalar-lesser + (check-scalar-operation sym/< [1 2 3 4] 2 [1 0 0 0])) + +(deftest test-symbol-lesser + (check-symbol-operation sym/< [1 2 3 4] [1 3 2 6] [0 1 0 1])) + +(deftest test-scalar-lesser-equal + (check-scalar-operation sym/<= [1 2 3 4] 2 [1 1 0 0])) + +(deftest test-symbol-lesser-equal + (check-symbol-operation sym/<= [1 2 3 4] [1 3 2 6] [1 1 0 1])) + +(deftest test-embedding + (let [data (sym/variable "data") + embed (sym/embedding "embed" {:data data :input-dim 10 :output-dim 4})] + (println "Embedded symbol:" (sym/to-json embed)))) + +(deftest test-binary-duplicate-input + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5) + arr-data (ndarray/copy data-tmp) + arr-grad (ndarray/* (ndarray/ones shape-vec) 3) + out-grad (ndarray/ones shape-vec) + square (sym/* data data) + exec-square (sym/bind square (context/default-context) arr-data arr-grad)] + (executor/forward exec-square) + (approx= 1e-6 (ndarray/* data-tmp data-tmp) (-> (executor/outputs exec-square) (first))) + (executor/backward exec-square out-grad) + (approx= 1e-6 (ndarray/* data-tmp 2) arr-grad))) + +(deftest test-sign + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5) + arr-data (ndarray/copy data-tmp) + arr-grad (ndarray/* (ndarray/ones shape-vec) 3) + + test (sym/sign data) + exec-test (sym/bind test (context/default-context) [arr-data] [arr-grad])] + (is (test-util/approx= 1e-6 + (-> (ndarray/sign data-tmp) (ndarray/->vec)) + (-> exec-test (executor/forward) (executor/outputs) first (ndarray/->vec)))) + (executor/backward exec-test (ndarray/* (ndarray/ones shape-vec) 2)) + (is (approx= 1e-6 (ndarray/zeros shape-vec) arr-grad)))) + + +(deftest test-round-ceil-floor + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5.543) + arr-data (ndarray/copy data-tmp) + arr-grad (ndarray/* (ndarray/ones shape-vec) 2) + + test (-> (sym/round data) + (sym/+ (sym/ceil data)) + (sym/+ (sym/floor data))) + exec-test (sym/bind test (context/default-context) [arr-data])] + (is (approx= 1e-6 + (-> (ndarray/round data-tmp) + (ndarray/+ (ndarray/ceil data-tmp)) + (ndarray/+ (ndarray/floor data-tmp))) + (-> (executor/forward exec-test) (executor/outputs) (first)))))) + + +(deftest test-rsqrt-cos-sin + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5) + arr-data (ndarray/copy data-tmp) + arr-grad (ndarray/* (ndarray/ones shape-vec) 3) + + test (-> (sym/rsqrt data) + (sym/+ (sym/cos data)) + (sym/+ (sym/sin data))) + exec-test (sym/bind test (context/default-context) [arr-data])] + (is (approx= 1e-6 + (-> (ndarray/rsqrt data-tmp) + (ndarray/+ (ndarray/cos data-tmp)) + (ndarray/+ (ndarray/sin data-tmp))) + (-> (executor/forward exec-test) (executor/outputs) (first)))))) + +(deftest test-maximum + (let [data1 (sym/variable "data") + data2 (sym/variable "data") + shape-vec [3 4] + data-tmp1 (random/uniform 0 100 shape-vec) + data-tmp2 (random/uniform 0 100 shape-vec) + + arr-data1 (ndarray/copy data-tmp1) + arr-data2 (ndarray/copy data-tmp2) + + test (sym/max data1 data2) + exec-test (sym/bind test (context/default-context) [arr-data1 arr-data2]) + out (-> (executor/forward exec-test) (executor/outputs) (first))] + (is (approx= 1e-6 + (mapv max (ndarray/->vec data-tmp1) (ndarray/->vec data-tmp2)) + out)))) + +(deftest test-minimun + (let [data1 (sym/variable "data") + data2 (sym/variable "data") + shape-vec [3 4] + data-tmp1 (random/uniform 0 100 shape-vec) + data-tmp2 (random/uniform 0 100 shape-vec) + + arr-data1 (ndarray/copy data-tmp1) + arr-data2 (ndarray/copy data-tmp2) + + test (sym/min data1 data2) + exec-test (sym/bind test (context/default-context) [arr-data1 arr-data2]) + out (-> (executor/forward exec-test) (executor/outputs) (first))] + (is (approx= 1e-6 + (mapv min (ndarray/->vec data-tmp1) (ndarray/->vec data-tmp2)) + out)))) + +(deftest test-transpose + (let [data (sym/variable "data") + test (sym/transpose data) + shape-vec [3 4] + ctx (context/default-context) + arr-data (random/uniform 0 100 shape-vec ctx) + trans (ndarray/transpose (ndarray/copy arr-data)) + exec-test (sym/bind test ctx {"data" arr-data}) + out (-> (executor/forward exec-test) + (executor/outputs) + (first))] + (is (approx= 1e-6 trans out)) + (is (= [4 3] (mx-shape/->vec (ndarray/shape out)))))) + +(deftest test-smooth-l1-and-make-loss + (let [data (sym/variable "data") + smooth-l1 (sym/smooth-l1 {:data data :scalar 1.0}) + loss (sym/make-loss {:data smooth-l1}) + shape-vec [2 6] + ctx (context/default-context) + input (ndarray/array [-3.5 -2.5 -1.5 -0.5 -0.3 -0.1 + 0.1 0.3 0.5 1.5 2.5 3.5] shape-vec) + grad (ndarray/empty shape-vec) + arr-tmp [3.0 2.0 1.0 0.125 0.045 0.005 + 0.005 0.045 0.125 1.0 2.0 3.0] + grad-tmp [-1.0 -1.0 -1.0 -0.5 -0.3 -0.1 + 0.1 0.3 0.5 1.0 1.0 1.0] + exec-test (sym/bind loss ctx {:data input} {:data grad}) + out (-> (executor/forward exec-test) (executor/outputs) first)] + (is (approx= 1e-6 arr-tmp out)) + (executor/backward exec-test) + (is (approx= 1e-6 grad-tmp grad)))) + +(deftest test-maximum-minimum-scalar + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 2) + arr-data (ndarray/copy data-tmp) + test (-> (sym/max data 3) + (sym/+ (sym/max data 9)) + (sym/+ (sym/min data 5)) + (sym/+ (sym/min data 4))) + exec-test (sym/bind test (context/default-context) [arr-data])] + ;; 3 + 9 + 2 + 2 + (is (approx= 1e-6 (ndarray/* (ndarray/ones shape-vec) 16) (-> (executor/forward exec-test) (executor/outputs) (first)))))) + +(deftest test-abs + (let [data (sym/variable "data") + shape-vec [3 4] + data-tmp (ndarray/* (ndarray/ones shape-vec) 5) + arr-data (ndarray/copy data-tmp) + arr-grad (ndarray/* (ndarray/ones shape-vec) 3) + test (sym/abs data) + exec-test (sym/bind test (context/default-context) arr-data arr-grad)] + (is (approx= 1e-6 (ndarray/abs data-tmp) (-> (executor/forward exec-test) (executor/outputs) (first)))) + + (let [out-grad (ndarray/* (ndarray/ones shape-vec) 2) + npout-grad (ndarray/* out-grad (ndarray/sign data-tmp))] + (executor/backward exec-test out-grad) + (is (approx= 1e-6 npout-grad arr-grad))))) + + + ;; configure A: input --> conv --> deconv --> output. + ;; the convolution and deconvoluiton has similar parameter which ensure + ;; the input shape is the same as output, and the same weights between conv + ;; and deconv; + ;; If the input value of forward() and backwrad() is the same, then +;; the output value of them should also the same; + +(defn check-deconvolution-forward-backward [{:keys[input-shape-vec num-filter kernel stride pad]}] + (let [data (sym/variable "data") + conv (sym/convolution "conv" {:data data :kernel kernel :stride stride + :pad pad :num-filter num-filter :no-bias "true"}) + deconv (sym/deconvolution "deconv" {:data conv :kernel kernel :stride stride + :pad pad :num-filter num-filter :no-bias "true"} ) + arg-names (sym/list-arguments deconv) + arg-shape-vecs (first (sym/infer-shape deconv {:data input-shape-vec})) + input-data (random/uniform -5 5 input-shape-vec) + out-grad input-data + conv-weight (random/normal 0 1 [num-filter (second input-shape-vec) (first kernel) (last kernel)]) + args {:data input-data :conv-weight conv-weight :deconv-weight conv-weight} + args-grad (mapv #(ndarray/empty %) arg-shape-vecs) + exec (sym/bind deconv (context/default-context) args args-grad) + out (-> (executor/forward exec) (executor/outputs) first)] + (executor/backward exec out-grad) + (is (approx= 1e-3 (ndarray/->vec out) (ndarray/->vec (first args-grad)))))) + + +(deftest test-deconvolution-forward-and-backward + (check-deconvolution-forward-backward {:input-shape-vec [1 1 5 5] :num-filter 1 :kernel [3 3] :stride [1 1] :pad [1 1]}) + (check-deconvolution-forward-backward {:input-shape-vec [32 3 28 28] :num-filter 3 :kernel [3 3] :stride [1 1] :pad [1 1]}) + ;; commented out to make the tests fast + #_(check-deconvolution-forward-backward {:input-shape-vec [10 3 403 403] :num-filter 3 :kernel [7 7] :stride [5 5] :pad [2 2]}) + ) + +;; configure A: input --> conv --> output. +;; configure B: input --> deconv --> output +;; the convolution and deconvoluiton has similar parameter which ensure +;; the input shape is the same as output; +;; During backward(), if the input of A equals output of B, and the output +;; of A equals input of B, then the grad of weight should be the same; + +(defn check-deconvolution-gradient [{:keys [input-shape-vec num-filter pad]}] + (let [stride [1 1] + kernel [(inc (* 2 (first pad))) (inc (* 2 (second pad)))] + data-conv (sym/variable "data_conv") + conv (sym/convolution "conv" {:data data-conv :kernel kernel :stride stride + :pad pad :num-filter num-filter :no-bias "true"}) + data-deconv (sym/variable "data_deconv") + deconv (sym/deconvolution "deconv" {:data data-deconv :kernel kernel :stride stride + :pad pad :num-filter num-filter :no-bias true}) + conv-data (random/uniform -5 5 input-shape-vec) + conv-args {"data_conv" conv-data "conv_weight" (random/normal 0 1 [num-filter (second input-shape-vec) (first kernel) (second kernel)])} + conv-args-grad [(ndarray/zeros (-> conv-data (ndarray/shape) (ndarray/->vec))) + (ndarray/zeros [num-filter (second input-shape-vec) (first kernel) (second kernel)])] + exec-conv (sym/bind conv (context/default-context) conv-args conv-args-grad) + conv-out-grad (random/normal 0 2 (-> (executor/outputs exec-conv) (first) (ndarray/shape) (mx-shape/->vec)))] + (executor/forward exec-conv) + (executor/backward exec-conv conv-out-grad) + + (let [deconv-data conv-out-grad + deconv-args {"data_deconv" deconv-data "deconv_weight" (get conv-args "conv_weight")} + deconv-args-grad [(ndarray/zeros (-> deconv-data (ndarray/shape) (mx-shape/->vec))) + (ndarray/zeros [num-filter (second input-shape-vec) (first kernel) (second kernel)])] + exec-deconv (sym/bind deconv (context/default-context) deconv-args deconv-args-grad) + deconv-out-grad conv-data] + (executor/forward exec-deconv) + (executor/backward exec-deconv deconv-out-grad) + + (is (approx= 1e-4 (ndarray/->vec (second conv-args-grad)) (ndarray/->vec (second deconv-args-grad))))))) + +(deftest test-deconvolution-gradient + (check-deconvolution-gradient {:input-shape-vec [1 3 5 5] :num-filter 3 :pad [1 1]})) + +(defn check-nearest-up-sampling-with-shape [{:keys [shape-vecs scale root-scale]}] + (let [arr (zipmap (map #(str "arg_" %) (range 0 (count shape-vecs))) + (map #(random/uniform -10 10 %) shape-vecs)) + arr-grad (zipmap (map #(str "arg_" %) (range 0 (count shape-vecs))) + (map #(ndarray/zeros %) shape-vecs)) + up-args (mapv #(sym/variable (str "arg_" %)) (range 0 (count shape-vecs))) + up (sym/up-sampling "up-sampling" nil up-args {:sample-type "nearest" :scale root-scale}) + exec (sym/bind up (context/default-context) arr arr-grad)] + (executor/forward exec) + (executor/backward exec (executor/outputs exec)) + (doseq [k (range 0 (count shape-vecs))] + (let [k-name (str "arg_" k) + expected (->> (get arr k-name) (ndarray/->vec) (mapv #(* % (Math/pow root-scale 2) (Math/pow scale (* 2 k))))) + real (-> (get arr-grad k-name) (ndarray/->vec))] + (is (approx= 0.1 expected real)))))) + + +(deftest test-nearest-upsampling + (doall (for [root-scale (range 1 4) + scale (range 1 4) + num-shape (range 1 4) + base (range 1 4)] + (let [shape-vecs (mapv (fn [i] [1 3 (* base root-scale (int (Math/pow scale (- (dec num-shape) i)))) + (* base root-scale (int (Math/pow scale (- (dec num-shape) i))))]) + (range 0 num-shape))] + (check-nearest-up-sampling-with-shape {:shape-vecs shape-vecs :scale scale :root-scale root-scale}))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj new file mode 100644 index 000000000000..1bf7db450135 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/optimizer_test.clj @@ -0,0 +1,30 @@ +(ns org.apache.clojure-mxnet.optimizer-test + (:require [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym] + [clojure.test :refer :all])) + +(defn test-optimizer [[opt-name optimizer-fn]] + (println "Testing optimizer - " opt-name) + (let [s (sym/variable "data") + s (sym/fully-connected {:data s :num-hidden 100}) + ;; single device + mod (m/module s {:data-names ["data"] :label-names nil})] + (-> mod + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-params) + (m/init-optimizer {:optimizer (optimizer-fn)}) + (m/update)))) + + +(deftest test-optimizer-update + (let [opts [["sgd" optimizer/sgd] + ["dcasgd" optimizer/dcasgd] + ["nag" optimizer/nag] + ["ada-delta" optimizer/ada-delta] + ["rms-prop" optimizer/rms-prop] + ["ada-grad" optimizer/ada-grad] + ["adam" optimizer/adam] + ["sgld" optimizer/sgld]]] + (doseq [opt opts] + (test-optimizer opt)))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj new file mode 100644 index 000000000000..cc1cc3b991a0 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj @@ -0,0 +1,37 @@ +(ns org.apache.clojure-mxnet.random-test + (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.random :as random] + [clojure.test :refer :all])) + +(deftest test-uniform-on-cpu + (let [ctx (context/default-context)] + (let [[a b] [-10 10] + shape [100 100] + _ (random/seed 128) + un1 (random/uniform a b shape {:context ctx}) + _ (random/seed 128) + un2 (random/uniform a b shape {:context ctx})] + (is (= un1 un2)) + (is (< (Math/abs + (/ (/ (apply + (ndarray/->vec un1)) + (- (ndarray/size un1) (+ a b))) + 2.0)) + 0.1))))) + +(deftest test-normal-on-cpu + (let [[mu sigma] [10 2] + shape [100 100] + _ (random/seed 128) + ret1 (random/normal mu sigma shape) + _ (random/seed 128) + ret2 (random/normal mu sigma shape)] + (is (= ret1 ret2)) + + (let [array (ndarray/->vec ret1) + mean (/ (apply + array) (count array)) + devs (map #(* (- % mean) (- % mean)) array) + stddev (Math/sqrt (/ (apply + devs) (count array)))] + (is (< (Math/abs (- mean mu)) 0.1)) + (is (< (Math/abs (- stddev sigma)) 0.1))))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/shape_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/shape_test.clj new file mode 100644 index 000000000000..2306ae97f965 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/shape_test.clj @@ -0,0 +1,11 @@ +(ns org.apache.clojure-mxnet.shape-test + (:require [org.apache.clojure-mxnet.shape :as mx-shape] + [clojure.test :refer :all])) + +(deftest test-to-string + (let [s (mx-shape/->shape [1 2 3])] + (is (= "(1,2,3)" (str s))))) + +(deftest test-equals + (is (= (mx-shape/->shape [1 2 3]) (mx-shape/->shape [1 2 3]))) + (is (not= (mx-shape/->shape [1 2]) (mx-shape/->shape [1 2 3])))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_test.clj new file mode 100644 index 000000000000..c4a407eda646 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_test.clj @@ -0,0 +1,47 @@ +(ns org.apache.clojure-mxnet.symbol-test + (:require [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.util :as util] + [clojure.test :refer :all])) + + +(deftest test-compose + (let [data (sym/variable "data") + net1 (sym/fully-connected "fc1" {:data data :num-hidden 10}) + net1 (sym/fully-connected "fc2" {:data net1 :num-hidden 100}) + + net2 (sym/fully-connected "fc3" {:num-hidden 10}) + net2 (sym/activation {:data net2 :act-type "relu"}) + net2 (sym/fully-connected "fc4" {:data net2 :num-hidden 20}) + + composed (sym/apply net2 "composed" {"fc3_data" net1}) + + multi-out (sym/group [composed net1])] + + (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1))) + (println (sym/debug-str composed)) + (is (= 2 (count (sym/list-outputs multi-out)))))) + +(deftest test-symbol-internal + (let [data (sym/variable "data") + oldfc (sym/fully-connected "fc1" {:data data :num-hidden 10}) + net1 (sym/fully-connected "fc2" {:data oldfc :num-hidden 100})] + (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1))) + (= (sym/list-arguments oldfc) (-> (sym/get-internals net1) + (sym/get "fc1_output") + (sym/list-arguments))))) + +(deftest test-infer-type + (let [data (sym/variable "data") + f32data (sym/cast {:data data :dtype "float32"}) + fc1 (sym/fully-connected "fc1" {:data f32data :num-hidden 128}) + mlp (sym/softmax-output "softmax" {:data fc1}) + [arg out aux] (sym/infer-type mlp {:data dtype/FLOAT64})] + (is (= [dtype/FLOAT64 dtype/FLOAT32 dtype/FLOAT32 dtype/FLOAT32] (util/buffer->vec arg))) + (is (= [dtype/FLOAT32 (util/buffer->vec out)])) + (is (= [] (util/buffer->vec aux))))) + +(deftest test-copy + (let [data (sym/variable "data") + data2 (sym/clone data)] + (is (= (sym/to-json data) (sym/to-json data2))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj new file mode 100644 index 000000000000..77f07aded8dc --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj @@ -0,0 +1,10 @@ +(ns org.apache.clojure-mxnet.test-util + (:require [clojure.test :as t])) + +(defn approx= [tolerance x y] + (if (and (number? x) (number? y)) + (let [diff (Math/abs (- x y))] + (< diff tolerance)) + (reduce (fn [x y] (and x y)) + (map #(approx= tolerance %1 %2) x y)))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj new file mode 100644 index 000000000000..d1ba40e72008 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -0,0 +1,168 @@ +(ns org.apache.clojure-mxnet.util-test + (:require [clojure.test :refer :all] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym] + [clojure.spec.alpha :as s]) + (:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray) + (scala.collection Map Set) + (scala.collection.mutable ArrayBuffer) + (scala.collection.immutable List IndexedSeq ListMap Vector) + (scala Option Tuple1 Tuple2 Tuple3))) + +(deftest test-empty-list + (let [x (util/empty-list)] + (is (instance? List x)) + (is (true? (.isEmpty x))))) + +(deftest test-empty-map + (let [x (util/empty-map)] + (is (instance? Map x)) + (is (true? (.isEmpty x))))) + +(deftest test-indexed-seq + (let [x (util/empty-indexed-seq)] + (is (instance? IndexedSeq x)) + (is (true? (.isEmpty x))))) + +(deftest test-empty-list-map + (let [x (util/empty-list-map)] + (is (instance? ListMap x)) + (is (true? (.isEmpty x))))) + +(deftest test->option + (let [x (util/->option 1)] + (is (instance? Option x)) + (is (= 1 (.get x))))) + +(deftest test-option->value + (is (= 2 (-> (util/->option 2) + (util/option->value))))) + +(deftest test-keyword->snake-case + (is (= [:foo-bar :foo2 :bar-bar]) + (util/keyword->snake-case [:foo_bar :foo2 :bar-bar]))) + +(deftest test-convert-tuple + (is (instance? Tuple1 (util/convert-tuple [1]))) + (is (instance? Tuple2 (util/convert-tuple [1 2]))) + (is (instance? Tuple3 (util/convert-tuple [1 2 3])))) + +(deftest test-convert-by-shape + (let [x (util/convert-by-shape {:a [100] :b "hi"})] + (is (instance? Shape (:a x))) + (is (= "hi" (:b x))))) + +(deftest tuple-convert-by-param-name + (let [x (util/tuple-convert-by-param-name {:foo [100] :kernel [3 3] :bar "hi"})] + (is (= "(3,3)" (:kernel x))) + (is (= [100] (:foo x))) + (is (= "hi" (:bar x))))) + +(deftest test-io-convert-by-param-name + (let [x (util/io-convert-by-param-name {:input-shape [10 10] :freeze? true :foo 1})] + (is (= "(10,10)" (:input-shape x))) + (is (= "True" (:freeze? x))) + (is (= "1" (:foo x))))) + +(deftest test-convert-map + (let [x (util/convert-map {:a [10] :b 1 :foo-bar 2})] + (is (instance? Map x)) + (is (= "Set(a, b, foo_bar)" (-> x (.keys) str))))) + +(deftest test-convert-vector + (let [x (util/convert-vector [1 2 3])] + (is (instance? List x)) + (is (= "List(1, 2, 3)" (str x))))) + +(deftest test-vec->set + (let [x (util/vec->set [1 2 3])] + (is (instance? Set x)) + (is (= "Set(1, 2, 3)" (str x))))) + +(deftest test-vec->indexed-seq + (let [x (util/vec->indexed-seq [1 2 3])] + (is (instance? Vector x)) + (is (= "Vector(1, 2, 3)" (str x))))) + +(deftest test-scala-function + (let [s-fn (util/scala-fn (fn [x] (+ x 2)))] + (is (= 4 (util/apply-scala-fn s-fn 2))))) + +(deftest test-coerce-param + (is (instance? Map (util/coerce-param {:x 1} #{"scala.collection.immutable.Map"}))) + (is (map? (util/coerce-param {:x 1} #{"float"}))) + + (is (float? (util/coerce-param 1 #{"float"}))) + + (is (instance? List (util/coerce-param (ndarray/ones [3]) #{"scala.collection.Seq"}))) + (is (instance? List (util/coerce-param (sym/variable "a") #{"scala.collection.Seq"}))) + (is (instance? List (util/coerce-param [1 2] #{"scala.collection.Seq"}))) + (is (instance? List (util/coerce-param [] #{"scala.collection.Seq"}))) + + (is (= "[I" (->> (util/coerce-param [1 2] #{"int<>"}) str (take 2) (apply str)))) + (is (= "[F" (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str)))) + (is (= "[L" (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str)))) + + (is (= 1 (util/coerce-param 1 #{"unknown"})))) + +(deftest test-nil-or-coerce-param + (is (instance? Map (util/nil-or-coerce-param {:x 1} #{"scala.collection.immutable.Map"}))) + (is (nil? (util/coerce-param nil #{"scala.collection.immutable.Map"})))) + +(deftest test-scala-map->map + (is (= {"a" 1} (-> (util/convert-map {:a 1}) + (util/scala-map->map))))) + +(deftest test-buffer->vec + (is (= [] (util/buffer->vec (ArrayBuffer.))))) + +(deftest test-scala-vector->vec + (is (= [1 2 3] (util/scala-vector->vec + (util/vec->indexed-seq [1 2 3]))))) + +(deftest test-scala-iterator->seq + (is (= [1 2 3] (-> (util/vec->indexed-seq [1 2 3]) + (.iterator) + (util/scala-iterator->seq))))) + +(deftest test-tuple->vec + (is (= [1 2] (-> (util/convert-tuple [1 2]) + (util/tuple->vec))))) + +(deftest test-coerce-return + (is (= [] (util/coerce-return (ArrayBuffer.)))) + (is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3])))) + (is (instance? NDArray + (util/coerce-return + (new NDArrayFuncReturn (into-array [(ndarray/zeros [3])]))))) + (is (= {"x" 1} (util/coerce-return + (util/convert-map {:x 1})))) + (is (= [1 2] (util/coerce-return + (util/convert-tuple [1 2])))) + (is (= [1 2 3] (util/coerce-return + (util/convert-tuple [1 2 3])))) + (is (= "foo" (util/coerce-return "foo")))) + +(deftest test-translate-keyword-shape + (let [[name shape] (util/translate-keyword-shape [:foo-a [5]])] + (is (= name "foo_a")) + (is (instance? Shape shape)) + (is (= "(5)" (str shape))))) + +(deftest test-map->tuple + (let [x (util/map->tuple {:foo-a [5]})] + (is (instance? Tuple2 (first x))) + (is (= "(foo_a,(5))" (str (first x)))))) + +(deftest test-list-map + (let [x (util/list-map {:x 1 :y 2})] + (is (instance? ListMap x)) + (is (= "Map(x -> 1, y -> 2)" (str x))))) + +(s/def ::x string?) + +(deftest test-validate + (is (nil? (util/validate! string? "foo" "Not a string!"))) + (is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!")))) diff --git a/contrib/clojure-package/testing.md b/contrib/clojure-package/testing.md new file mode 100644 index 000000000000..8f87c2a76e81 --- /dev/null +++ b/contrib/clojure-package/testing.md @@ -0,0 +1,23 @@ +## Help with Testing + +If you want to give the repo a spin and help make it stable and ready for prime time that would be awesome. + +Here is what you can do. + +* Clone the project +* Edit the project.clj file and uncomment the line that is for your system (OSX, Linux CPU, or Linux GPU) +* Run `lein deps` (this might take a bit - the jars are big!) +* Run `lein test` - there should be no errors. The tests are all cpu +* Run `lein install` to install the clojure-package locally +* Go to the module examples `cd examples/module` +* Either run `lein run` or `lein run :gpu` + +If you find any problems, please log on issue. + +Thanks! + +## Want to explore more? + +The examples/tutorial is a good REPL walkthrough +The examples/pre-trained-modules is nice too +The examples/gan is just plain fun :) diff --git a/contrib/clojure-package/vi ci-test.sh b/contrib/clojure-package/vi ci-test.sh new file mode 100644 index 000000000000..dc12f1389688 --- /dev/null +++ b/contrib/clojure-package/vi ci-test.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +set -evx + +cd contrib/clojure-package +lein test