8wDlpd.png
8wDFp9.png
8wDEOx.png
8wDMfH.png
8wDKte.png

如果使用 tf.lite.interpreter,循环中会发生内存泄漏

keithjgrant 2月前

33 0

我正在开发一个 FastAPI 项目,该项目接受带有模型名称和数据的请求。这些数据经过标准路径:(预处理、模型处理、后处理)。解释器...

我正在开发一个 FastAPI 项目,该项目接受带有模型名称和数据的请求。这些数据经过标准路径:(预处理、模型处理、后处理)。每个模型的解释器都存储在字典中,以便再次调用该模型。

由于这些型号数量较多(超过 500 种),占用的 RAM 量接近 3 GB。

我想尝试减少 RAM 消耗,并尝试删除 tf.lite.interpreter 保存,但这导致了内存泄漏。据我了解,这是因为 tf.lite.interpreter 每次创建实例时都会在内存中存储一​​些数据。

我想知道是否可以在处理后删除 tf.lite.interpreter 创建的所有数据。以减少 RAM 消耗。还是假设 tf.lite.interpreter 是为模型创建的,并在运行过程中不断存储?有没有办法在使用模型后释放内存?我的方法有意义吗?

提前感谢您的帮助。我刚刚开始使用 ML。

一开始我用memory_profiler分析了内存。它显示内存是在使用tf.lite.Interpreter()的行上分配的

之后我创建了一个小脚本用于测试。

from time import perf_counter, sleep
import tensorflow as tf
import psutil
import logging
from logging.handlers import RotatingFileHandler


PATH = ...


rotating_handler = RotatingFileHandler(
    filename='logs/memory_leak.log', 
    mode='a+', 
    maxBytes=int(20e6),
    backupCount=10, 
    encoding='utf-8'
)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
rotating_handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)
root_logger.addHandler(rotating_handler)

def create_interpreter(path:str):
    with tf.device('/CPU:0'):
        timer_start: float = perf_counter()
        interpreter = tf.lite.Interpreter(model_path=path)
        interpreter.allocate_tensors()
        create_time = perf_counter() - timer_start
        print(f'ID interpreter: {id(interpreter)}')
        print(f'Time: {create_time}')
        process = psutil.Process()
        mem_info = process.memory_info()
        logging.warning(f'Memory usage: RSS={mem_info.rss / 1024 ** 2:.2f} MB, VMS={mem_info.vms / 1024 ** 2:.2f} MB')
        return interpreter

while True:
    some_interpreter = create_interpreter(path=PATH)
    some_interpreter.allocate_tensors()
    sleep(5)
    del some_interpreter

我知道del删除唯一的链接但是我认为之后python GC会清除内存。根据该脚本的日志,python进程消耗的内存正在慢慢增加。

帖子版权声明 1、本帖标题:如果使用 tf.lite.interpreter,循环中会发生内存泄漏
    本站网址:http://xjnalaquan.com/
2、本网站的资源部分来源于网络,如有侵权,请联系站长进行删除处理。
3、会员发帖仅代表会员个人观点,并不代表本站赞同其观点和对其真实性负责。
4、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报
5、站长邮箱:yeweds@126.com 除非注明,本帖由keithjgrant在本站《tensorflow》版块原创发布, 转载请注明出处!
最新回复 (0)
  • 使用给定的平面列表:let list = [ { key: 1, parent: null, }, { key: 2, parent: 1, }, { key: 3, parent: null, }, { ...

    使用给定的平面 列表

    let list = [
        {
            key: 1,
            parent: null,
        },
        {
            key: 2,
            parent: 1,
        },
        {
            key: 3,
            parent: null,
        },
        {
            key: 4,
            parent: 1,
        },
        {
            key: 5,
            parent: 2,
        }
    ]
    

    如何创建像下面这样的嵌套对象?

    let nest = {
        children: [
            {
                key: 1,
                children: [
                    {
                        key: 2,
                        children: [
                            {
                                key: 5,
                                children: []
                            }
                        ]
                    },
                    {
                        key: 4,
                        children: []
                    }
                ]
            },
            {
                key: 3,
                children: []
            }
        ]
    }
    

    我不知道该如何解决这个问题。我想到的解决方案是必须反复遍历列表,以检查对象的父对象是否为空(在这种情况下,它将被分配为顶级对象),或者对象的父对象已经存在(在这种情况下,我们获取父对象的路径,并将子对象分配给该父对象)。

    PSI 认为这不是以下任何内容的重复

    • 将检查平面物体中是否存在键。
    • 并没有显示任何可以返回路径的内容。
  • 为了构建树,您可以使用单循环方法,不仅使用给定的方法 key 来构建节点,而且 parent 还使用给定的方法来构建节点,其中依赖性显然是存在的。

    它使用一个对象,其中所有键都用作引用,例如

    {
        1: {
            key: 1,
            children: [
                {
                    /**id:4**/
                    key: 2,
                    children: [
                        {
                            /**id:6**/
                            key: 5,
                            children: []
                        }
                    ]
                },
                {
                    /**id:8**/
                    key: 4,
                    children: []
                }
            ]
        },
        2: /**ref:4**/,
        3: {
            key: 3,
            children: []
        },
        4: /**ref:8**/,
        5: /**ref:6**/
    }
    

    除了单循环之外,它的主要优点是,它可以处理未排序的数据,因为它可以一起使用结构 keys parent 信息。

    var list = [{ key: 1, parent: null, }, { key: 2, parent: 1, }, { key: 3, parent: null, }, { key: 4, parent: 1, }, { key: 5, parent: 2, }],
        tree = function (data, root) {
            var r = [], o = {};
            data.forEach(function (a) {
                var temp = { key: a.key };
                temp.children = o[a.key] && o[a.key].children || [];
                o[a.key] = temp;
                if (a.parent === root) {
                    r.push(temp);
                } else {
                    o[a.parent] = o[a.parent] || {};
                    o[a.parent].children = o[a.parent].children || [];
                    o[a.parent].children.push(temp);
                }
            });
            return r;
        }(list, null),
        nest = { children: tree };
    
    console.log(nest);
    .as-console-wrapper { max-height: 100% !important; top: 0; }
  • 我有一个.keras 文件(我将用它来加载我的模型)和一个巨大的数据集 X(~500 万行),我想调用 model.predict(X) 来检索我的预测。有没有办法利用多进程......

    我有一个.keras 文件(我将用它来加载我的模型)和一个巨大的数据集 X(~500 万行),我想调用 model.predict(X) 来检索我的预测。

    有没有办法利用多处理/线程来加快速度?我考虑将 X 拆分成不同的块,然后对这些子集(并行)调用 model.predict()。但是 Python 的多处理需要对对象进行 pickle,因此我无法传递 keras.Model 作为参数,这意味着我必须为每个线程/工作者/实例加载一个新模型,这非常慢。

    对此的总体策略是什么?

  • 我正在构建一个聊天机器人应用程序,该应用程序连接到 SQL 服务器,抓取自定义数据表,将它们转换为 pandas 数据框,然后将它们与 Langchain 代理一起使用。页面加载/签名时...

    我正在构建一个聊天机器人应用程序,该应用程序连接到 SQL 服务器,抓取自定义数据表,将它们转换为 pandas 数据框,然后将它们与 Langchain 代理一起使用。在页面加载/登录时,用户的电子邮件用于确定从同一数据库的不同 SQL 表中提取的敏感信息的访问级别。这些表存储在数据框中,并与我的 llm 一起传递给 pandas Langchain 代理。每当用户提出问题时,它都会调用代理,然后代理会做出回应。在我的后端,我还会检查用户是否想要某种形式的可视化,并使用 plotly、数据框和我让代理创建的自定义代码来创建它。

    在本地,当我为一个用户及其电子邮件进行测试时,这种方法确实很有效。但是,我不确定当许多用户同时使用该网站时如何扩展它。我正在努力弄清楚如何确保数据帧、代理/llm 的内存和代理本身可以为每个人初始化并在众多请求中使用。

    目前,这是我对网络应用程序启动的实现:

    from flask import Flask, g, sessions, request, jsonify, render_template, session
    from flask_session import Session
    from sqlalchemy import create_engine, text
    import pandas as pd
    import os
    import urllib.parse
    from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
    from langchain_openai import AzureChatOpenAI
    from langchain import prompts
    from langchain.agents.agent_types import AgentType
    from langchain.memory import ConversationBufferWindowMemory
    import seaborn as sns
    import plotly.express as px
    import re
    import plotly.io as pio
    import logging
    from sqlalchemy.engine import URL
    from flask_sqlalchemy import SQLAlchemy
    
    app = Flask(__name__, template_folder='templates')
    secret_key = os.environ['FLASK_SECRET_KEY']
    app.secret_key = secret_key
    
    user = os.environ["USER_NAME"]
    password = os.environ["PASSWORD"]
    hostName = os.environ["HOSTNAME"]
    port = os.environ["SQLPORT"]
    db = os.environ["DATABASE"]
    db_context = os.environ["CONTEXT"]
    
    params = urllib.parse.quote_plus(
        "Driver={ODBC Driver 17 for SQL Server};"
        f"Server=tcp:{hostName},1433;"
        f"Database={db};"
        f"Uid={user};"
        f"Pwd={password};"
        "Encrypt=yes;"
        "TrustServerCertificate=yes;" #yes if running locally, no for production
        "Connection Timeout=240;"
    )
    
    connection_string = f"mssql+pyodbc:///?odbc_connect={params}"
    app.config['SQLALCHEMY_DATABASE_URI'] = connection_string
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    db = SQLAlchemy(app)
    
    @app.route('/')
    def home():
        global dataframes, memory, chat_prompt, context, userName, userImage, agent
        chat_prompt = chat_prompt_init()
        email = request.args.get('email')
    
        logging.info('email got')
        if not email:
            return render_template('no_email.html')
    
        dataframes = load_data_frames1(email)
    
        context = dataframes.get("context")
        contractDetails = dataframes.get("contract_details")
        projectDetails = dataframes.get("project_details")
        storeDetails = dataframes.get("store_details")
        userName = dataframes.get("name")
        userImage = dataframes.get("profile_pic")
        memory = ConversationBufferWindowMemory()
    
        client = AzureChatOpenAI(
            api_key=os.environ["AZURE_OPENAI_KEY"],  
            api_version="2023-12-01-preview",
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
            deployment_name=os.environ["DEPLOYMENT_NAME"],
            temperature=0
        )
    
        agent = create_pandas_dataframe_agent(llm=client, df=[storeDetails, contractDetails, projectDetails], verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS,  return_intermediate_steps=True)
    
        return render_template('Index.html')   
    

    当用户按下发送按钮或点击输入他们的查询/问题时,这是我的路线:

    @app.route('/process_query', methods=['POST'])
    def process_query():
        if not agent:
            return jsonify({"error": "Agent not initialized"}), 500
        
        # Extract query from the POST request
        data = request.json
        query = data.get('query')
        
        if not query:
            return jsonify({"error": "No query provided"}), 400
        
        try:
            response = agent.invoke(chat_prompt.format_prompt(query=query, chat_history=memory.buffer_as_messages, context=context).to_messages())
            if response["intermediate_steps"]:
                    queries = extract_queries(response['intermediate_steps'])
            else:
                queries = None
            
            pattern = r'```python\s(.*?)```'
            code_snippets = re.findall(pattern, response["output"], re.DOTALL)
            
            if code_snippets:
                    cleaned_snippets = [snippet.replace("python\n", "", 1) for snippet in code_snippets]
            else:
                    cleaned_snippets = None
                
            graph = extract_graph_code(output_code=cleaned_snippets, queries=queries)
            
            clean_response = remove_code(response["output"])
            
            memory.chat_memory.add_user_message(query)
            memory.chat_memory.add_ai_message(clean_response)
            
            return jsonify({"response": clean_response,
                            "graph": graph})
        except Exception as e:
            return jsonify({"error": str(e)}), 500
    

    通常,如果表是简单读取且代理没有内存,那么将这些对象编码到 process_query() 函数中将非常容易。但是,我需要确保每次运行处理查询的请求时代理都不会初始化,因为我需要为每个用户提供持久内存。此外,我无法让每个请求都生成数据帧到 process_query(),因为为了抓取表并生成数据帧,平均需要 27 秒(我还没有找到简化 SQL 查询的方法,以便根据用户访问级别抓取和生成自定义表)。

    我需要找到一种方法,使所有对象不作为全局变量(数据框、代理、内存)。我尝试将数据框设为 Json 对象并将它们传递到会话 cookie 中,但仅一个数据框就已经太大了。我一直在尝试在线查找可以做到这一点的不同方法,但似乎找不到好的答案。我见过人们使用 Redis 进行内存存储,但我没有,希望找到替代方案(如果没有,请告诉我)。我计划将此应用程序托管为 Azure Web 应用程序;不确定这是否有助于找到解决我的问题的方法。

    我考虑过为每个用户创建一个唯一的、随机生成的 ID,并将代理、内存和数据帧存储在全局字典中,并以 ID 作为密钥来确保它们的安全。注销时,我会删除键值对,以确保字典不会变得太大,从而防止出现内存和空间问题。我非常怀疑这在生产环境中是否是最佳选择,但这是我目前能想到的全部。

    我如何获取每个用户的数据/对象并将它们传递到 Flask 中的不同路由?请让我知道我还能做什么或尝试什么!

  • 您可能需要确定哪些数据是全局的、每个用户的和短暂的(可以丢弃)。我认为您关于地图的想法对用户代码来说并不是很糟糕,只要您跟踪所涉及的内存并且不要过度即可。对于长时间运行的会话,您可以使用某种离线存储(如 redis、sqlite 或只是哑文件),但这一切都取决于数据的大小

返回
作者最近主题: