8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

如何利用小型神经网络高效地做出大量预测?

Beni Cherniavsky-Paskin 2月前

26 0

我需要使用小型神经网络(100-150 个参数)进行大量预测。我在 TensorFlow 中实现了它,但遇到了效率问题。这是伪代码:for my_dense_netowrk,my_lstm_n...

我需要用小型神经网络(100-150 个参数)进行大量预测。我在 TensorFlow 中实现了它,但遇到了效率问题。以下是伪代码:

for my_dense_netowrk,my_lstm_netowrk in networks_list
    my_dense_netowrk.paramters = 100
    my_lstm_netowrk.paramters = 150
    for images in data[:60]:
        @tf.function
        def tf_wrapper(images, state):
           model_data = meta_model(images)
           data_prepared = image_preparation(model_data)

           results = my_dense_netowrk(data_prepared)
           results.shape = (19000,1,1)
           
           better_results, state = my_lstm_netowrk(results, state)
           return better_results, state
           
        better_results, state = tf_wrapper(images, state)

my_dense_netowrk_n2.paramters = 100
my_lstm_netowrk_n2.paramters = 100

and continue...
  1. 我使用 tensorflow 数据管道 api,实际上所有必需的数据(数据变量)都可以分配在我的内存中。
  2. 在构建神经网络时,我没有指定批处理大小,而是将大量数据(批处理大小为 19000)插入神经网络以并行化所有内容。即使是 LSTM 也不会受到序列处理的限制,因为它必须一次处理 19000 个输入。但是,当我将神经网络参数增加 10 倍(我不需要)时,我的代码几乎没有注意到它认为批处理大小相当大。
  3. @tf.fcuntion 稍微加快一切速度。
  4. 我尝试了性能分析,但由于操作太多,未能找到瓶颈。我发现内核启动占用了一半的时间,因为通常 TensorFlow 预计这个过程会花费大量时间,所以我猜它没有针对此类任务进行优化,因为当我将循环从 60 增加到 6000 时,每次循环的效率都会提高 10 倍!看来准备工作需要时间。
  5. image_preparation() 函数仅使用 tf ops(例如 reshape、stack、tile),而且我无法提前准备数据。
  6. 我使用带有 M3 Max 芯片的 macOS,使用 GPU 或 CPU 没有区别。我尝试了 python 3.8、3.9、3.10、3.11、3.12。

因此,似乎 tensorflow 并没有被我的模型所限制,这很奇怪,而且互联网上关于如何有效地从小型模型中获得许多预测的讨论并不多,每个人都使用这样的库来处理大型 NN。虽然我认为我的管道应该会从中受益,因为我使用了大批量,但 gpu 根本没有帮助。所以我真的很难找到一个好的解决方案,想征求一下建议。也许有更好的 ml 框架可以解决我的问题(PyTorch、Jax,也许还有其他的?)或者我只是不擅长分析?或者我应该尝试用汇编语言构建自己的内核吗?我不知道

帖子版权声明 1、本帖标题:如何利用小型神经网络高效地做出大量预测?
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由Beni Cherniavsky-Paskin在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
  • 我目前的项目包括构建一个 Angular 前端来捕获 POST 请求 API,该 API 既从我的 Amazon 发送数据(例如“image_id”、“modules”等列表)

    我当前的项目包括构建一个 Angular 前端来捕获 POST 请求 API,该 API 既从我的 Amazon DynamoDB 表发送数据(例如 \'image_id\'、\'modules\' 等列表),也从 Angular 验证传入的输入(也是 image_id 等列表),具体取决于有效负载(后端 AWS lambda 函数是 Java,我的具体任务是构建数据检索)。我正处于项目阶段,需要将我的 lambda 函数连接到 DynamoDB 并获取特定数据(我想要获取的第一个数据是 \'image_id\',包含在 \'container_master\' 表中)。表的每个项目中都有一个 \'image_id\' 键,我现在希望能够收集所有 id。但是,我遇到了障碍。问题是,当“image_id”数组显示在网站的控制台上时,它是空的。我认为问题是由于 lambda 函数对数据库中扫描项目的处理造成的,因为任何其他硬编码并发送到我的网站的数组都可以很好地呈现。

    HomeScreen_Entity.java:

    @DynamoDBTable(tableName = "container_master")
    public class HomeScreen_Entity {
        private String image_id;
        AmazonDynamoDB dynamoDBClient;
        private Regions REGION = Regions.US_WEST_2;
        DynamoDBMapper mapper;
        public HomeScreen_Entity() {
            dynamoDBClient = AmazonDynamoDBClientBuilder.standard().withRegion(REGION).build();
            mapper = new DynamoDBMapper(dynamoDBClient);
        }
        
        @DynamoDBHashKey(attributeName = "image_id")
        public String getImage_id() {
            return image_id;
        }
        public void setImage_id(String image_id) {
            this.image_id = image_id;
        }
        public List<String> getAllImageIds() { //HomeScreen_Entity.java
            
            ScanRequest scanRequest = new ScanRequest()
                    .withTableName("container_master")
                    .withAttributesToGet("image_id");
            ScanResult result = dynamoDBClient.scan(scanRequest);
            List<String> imageIds = new ArrayList<>();
            for (Map<String, com.amazonaws.services.dynamodbv2.model.AttributeValue> item : result.getItems()) {
                if (item.containsKey("image_id")) {
                    imageIds.add(item.get("image_id").getS());
                }
            }
        }
        
    }
    

    上面的代码处理DynamoDB表的扫描。ECS_DAO.java:

    public List<HomeScreen_Entity> fetchData(Payload_Entity pe) { //ECS_DAO.java
            List<HomeScreen_Entity> entities = new ArrayList<>();
            try {
                HomeScreen_Entity homeScreenEntity = new HomeScreen_Entity();
                List<String> imageIds = homeScreenEntity.getAllImageIds();
                for (String id : imageIds) {
                    HomeScreen_Entity entity = new HomeScreen_Entity();
                    entity.setImage_id(id);
                    entities.add(entity);
                    }
                } catch (Exception ex) {
                    ex.printStackTrace();
                    System.out.println("No records found");
                }
            
            return entities;
        }
    

    上面的代码将处理所有常规数据(目前,我专注于图像 ID。ECS_BO.java:

    public List<String> retrieveDataFromDynamoDB(Payload_Entity pe) { //ECS_BO.java
            ECS_DAO db = new ECS_DAO();
            List<HomeScreen_Entity> data = db.fetchData(pe);
            List<String> imageIds = data.stream()
                    .map(HomeScreen_Entity::getImage_id)
                    .collect(Collectors.toList());
            return imageIds;
        }
    

    上述代码将 HomeScreen_Entity 对象转换为 List以便稍后可以将其转换为 JSON 并呈现。MainController.java(只是为了好用):

    case RETRIEVE_DATA: //MainController.java. payload change to "retrieve" instead of "execute"/"running"
                        ECS_BO ecsbo = new ECS_BO();
                        body = ecsbo.retrieveDataFromDynamoDB(pe);
                        return buildBodyApiResponse(context, 200, body);
    ...
    

    我很难弄清楚为什么 \'body\' 返回一个空数组。我猜是数据扫描处理存在问题。免责声明:DynamoDB 表中应该有几个项目。另一个免责声明:image_id 是分区键。

  • 在您的 for 循环中添加一个日志,以便您可以查看是否从 DynamoDB 返回项目。

返回
作者最近主题: