Integrating TensorFlow into Spring Boot for Image Detection
1. What is TensorFlow?
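The name TensorFlow comes from tensors flowing through a computational graph. It is built on automatic differentiation over that graph: besides computing gradients for you, it provides the common operations (ops, the nodes of the computational graph), loss functions, and optimization algorithms.
TensorFlow is an open-source software library for high-performance numerical computation. Its flexible architecture lets users deploy computation to multiple platforms (CPU, GPU, TPU) and devices (desktops, server clusters, mobile devices, edge devices, and so on).
TensorFlow is also an open-source machine learning library for research and production. It offers a range of APIs that let beginners and experts develop for desktop, mobile, web, and cloud environments.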
TensorFlow computes with data flow graphs, so you first build a data flow graph and then feed your data, in the form of tensors, into it for computation. Nodes in the graph represent mathematical operations, and edges represent the multidimensional data arrays (tensors) passed between nodes. During training, tensors keep flowing from one node of the graph to another, which is where the name TensorFlow comes from.
Tensors come in several orders. A zero-order tensor is a scalar, that is, a single number such as 1. A first-order tensor is a vector, such as the one-dimensional [1, 2, 3]. A second-order tensor is a matrix, such as the two-dimensional [[1, 2, 3],[4, 5, 6],[7, 8, 9]]. The pattern continues with third-order (three-dimensional) tensors and beyond. The picture of tensors flowing from one end of the graph to the other describes how complex data structures are passed along, analyzed, and processed inside an artificial neural network.
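In machine learning, numeric values usually take one of four forms:
(1) Scalar: a single number, the smallest unit of computation, such as 1 or 3.2.
(2) Vector: a one-dimensional array of scalars, such as [1, 3.2, 4.6].
(3) Matrix: a two-dimensional array of scalars.
(4) Tensor: a collection of data organized as a multidimensional array (usually with more than two dimensions), which can be thought of as a higher-dimensional matrix.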
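To make the idea of tensor order (rank) concrete, here is a minimal sketch that uses plain Java arrays rather than TensorFlow itself. The class name `TensorRankDemo` and the helper `rank` are hypothetical names chosen for illustration; the helper simply counts how many array dimensions a value has.

```java
final class TensorRankDemo {

    // Rank = number of nested array dimensions; a plain number has rank 0.
    static int rank(Object value) {
        int rank = 0;
        Class<?> type = value.getClass();
        while (type.isArray()) {
            rank++;
            type = type.getComponentType();
        }
        return rank;
    }

    public static void main(String[] args) {
        float scalar = 3.2f;                                  // zero-order tensor
        float[] vector = {1f, 3.2f, 4.6f};                    // first-order tensor
        float[][] matrix = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; // second-order tensor
        float[][][] cube = new float[2][3][4];                // third-order tensor

        System.out.println(rank(scalar)); // 0
        System.out.println(rank(vector)); // 1
        System.out.println(rank(matrix)); // 2
        System.out.println(rank(cube));   // 3
    }
}
```

A real TensorFlow tensor also carries a data type and a shape (the size of each dimension); the rank is the length of that shape.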
Basic TensorFlow concepts
- Graph: describes the computation; TensorFlow represents a computation as a graph
- Tensor: TensorFlow represents data as tensors; each tensor is a multidimensional array
- Operation: a node in the graph is an op; an op takes zero or more tensors as input, performs a computation, and produces zero or more tensors
- Session: in the graph-based API, a graph has to be run inside a session
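The relationship between graph, op, and session can be sketched with a toy dataflow graph in plain Java. This is not the TensorFlow API: `MiniGraph`, `Node`, `constant`, `op`, and `run` are hypothetical names, and the "tensors" here are single doubles. It only illustrates that building the graph and running it are separate steps.

```java
import java.util.function.DoubleBinaryOperator;

final class MiniGraph {

    // A node produces a value when evaluated; edges are the references between nodes.
    interface Node {
        double eval();
    }

    // A constant node has no inputs.
    static Node constant(double value) {
        return () -> value;
    }

    // An op node takes two input nodes and combines their values.
    static Node op(DoubleBinaryOperator f, Node a, Node b) {
        return () -> f.applyAsDouble(a.eval(), b.eval());
    }

    // Plays the role of a session: nothing is computed until a node is fetched.
    static double run(Node fetch) {
        return fetch.eval();
    }

    public static void main(String[] args) {
        Node x = constant(2);
        Node y = constant(3);
        Node sum = op(Double::sum, x, y);           // x + y
        Node product = op((a, b) -> a * b, sum, y); // (x + y) * y
        System.out.println(run(product));           // 15.0
    }
}
```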
Typical workflow for writing TensorFlow code
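- Define the variables and placeholders
- Write the model equation from the underlying math
- Define the loss function (cost)
- Define the optimizer, for example gradient descent with GradientDescentOptimizer
- Train inside a session, using a for loop
- Save the model with a Saver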
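The workflow above can be sketched without TensorFlow as a hand-written gradient descent loop in plain Java. This is an illustration under stated assumptions, not the TensorFlow API: the class `LinearRegressionSketch`, the model y = w*x + b, the mean-squared-error loss, and the sample data are all chosen for the example, and the gradients are written out by hand where TensorFlow would derive them automatically.

```java
final class LinearRegressionSketch {

    // Mean squared error of the model y = w*x + b on the data (the "cost").
    static double loss(double[] xs, double[] ys, double w, double b) {
        double sum = 0;
        for (int i = 0; i < xs.length; i++) {
            double error = (w * xs[i] + b) - ys[i];
            sum += error * error;
        }
        return sum / xs.length;
    }

    // Batch gradient descent; returns the fitted parameters {w, b}.
    static double[] fit(double[] xs, double[] ys, double learningRate, int steps) {
        double w = 0; // the "variables" being trained
        double b = 0;
        int n = xs.length;
        for (int step = 0; step < steps; step++) { // the training loop
            double gradW = 0;
            double gradB = 0;
            for (int i = 0; i < n; i++) {
                double error = (w * xs[i] + b) - ys[i];
                gradW += 2 * error * xs[i] / n; // d(loss)/dw
                gradB += 2 * error / n;         // d(loss)/db
            }
            w -= learningRate * gradW; // gradient descent update
            b -= learningRate * gradB;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] xs = {0, 1, 2, 3, 4};
        double[] ys = {1, 3, 5, 7, 9}; // generated from y = 2x + 1
        double[] params = fit(xs, ys, 0.05, 5000);
        // The fitted values approach w = 2 and b = 1.
        System.out.printf("w=%.4f b=%.4f loss=%.6f%n",
                params[0], params[1], loss(xs, ys, params[0], params[1]));
    }
}
```

The final "save" step has no counterpart here; in TensorFlow it persists the trained variables so that an application such as the one in this article can load them later.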
2. Environment preparation
Integration steps
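- Model building: first, define and train the deep learning model in TensorFlow. This may involve choosing a suitable network architecture, optimizer, loss function, and so on.
- Training data preparation: next, prepare the data used to train and validate the model. This may include data cleaning, labeling, and preprocessing.
- REST API design: to interact with the TensorFlow model, create a REST API in Spring Boot. Spring Boot's built-in support, such as Spring MVC or Spring WebFlux, is enough for this.
- Model deployment: once the model is trained, deploy it into the Spring Boot application. Export the model from the training environment in a format such as SavedModel or a frozen graph (.pb), then load and run it in the Spring Boot application through the TensorFlow Java API. This article loads a frozen Inception v3 graph.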
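A few points need attention during integration. First, firewall settings can block network traffic that TensorFlow depends on, for example downloading datasets or pretrained models, or communication between workers in distributed training. Make sure the firewall allows TensorFlow to reach the network resources it needs so that training is not interrupted. Second, watch version compatibility. Spring Boot and TensorFlow each have their own release cycles, and using compatible versions avoids a lot of unnecessary trouble.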
3. Code project
Goal
Implement image detection: classify an uploaded image and return the most likely label.
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>springboot-demo</artifactId>
        <groupId>com.et</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Tensorflow</artifactId>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>jmimemagic</groupId>
            <artifactId>jmimemagic</artifactId>
            <version>0.1.2</version>
        </dependency>
        <dependency>
            <groupId>jakarta.platform</groupId>
            <artifactId>jakarta.jakartaee-api</artifactId>
            <version>9.0.0</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.16.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.restdocs</groupId>
            <artifactId>spring-restdocs-mockmvc</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>
controller
package com.et.tf.api;

import java.io.IOException;

import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@RequestMapping("/api")
public class AppController {

    @Autowired
    ClassifyImageService classifyImageService;

    // POST /api/classify: accepts a multipart upload named "file" and returns the best label.
    @PostMapping(value = "/classify")
    @CrossOrigin(origins = "*")
    public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {
        checkImageContents(file);
        return classifyImageService.classifyImage(file.getBytes());
    }

    // Because this is a @RestController, the literal string "index" is returned as the response body.
    @RequestMapping(value = "/")
    public String index() {
        return "index";
    }

    // Rejects uploads whose detected MIME type is not an image type.
    private void checkImageContents(MultipartFile file) {
        MagicMatch match;
        try {
            match = Magic.getMagicMatch(file.getBytes());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        String mimeType = match.getMimeType();
        if (!mimeType.startsWith("image")) {
            throw new IllegalArgumentException("Not an image type: " + mimeType);
        }
    }
}
service
package com.et.tf.service;

import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

// Inspired from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {

    private final Session session;
    private final List<String> labels;
    private final String outputLayer;

    private final int W;
    private final int H;
    private final float mean;
    private final float scale;

    public ClassifyImageService(
        Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,
        @Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,
        @Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale
    ) {
        this.labels = labels;
        this.outputLayer = outputLayer;
        this.H = imageH;
        this.W = imageW;
        this.mean = mean;
        this.scale = scale;
        this.session = new Session(inceptionGraph);
    }

    // Decodes and normalizes the image, runs the model, and returns the most probable label.
    public LabelWithProbability classifyImage(byte[] imageBytes) {
        long start = System.currentTimeMillis();
        try (Tensor image = normalizedImageToTensor(imageBytes)) {
            float[] labelProbabilities = classifyImageProbabilities(image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            LabelWithProbability labelWithProbability =
                new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);
            log.debug(String.format(
                "Image classification [%s %.2f%%] took %d ms",
                labelWithProbability.getLabel(),
                labelWithProbability.getProbability(),
                labelWithProbability.getElapsed()
                )
            );
            return labelWithProbability;
        }
    }

    // Feeds the image into the "input" node and reads the probabilities from the output layer.
    private float[] classifyImageProbabilities(Tensor image) {
        try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {
            final Shape resultShape = result.shape();
            final long[] rShape = resultShape.asArray();
            if (resultShape.numDimensions() != 2 || rShape[0] != 1) {
                throw new RuntimeException(
                    String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rShape)
                    ));
            }
            int nlabels = (int) rShape[1];
            FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();
            float[] dst = new float[nlabels];
            resultFloatBuffer.read(dst);
            return dst;
        }
    }

    // Index of the largest probability.
    private int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    private Tensor normalizedImageToTensor(byte[] imageBytes) {
        try (Graph g = new Graph();
             TInt32 batchTensor = TInt32.scalarOf(0);
             TInt32 sizeTensor = TInt32.vectorOf(H, W);
             TFloat32 meanTensor = TFloat32.scalarOf(mean);
             TFloat32 scaleTensor = TFloat32.scalarOf(scale);
        ) {
            GraphBuilder b = new GraphBuilder(g);
            // Tutorial python here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
            //
            // - The model was trained with images scaled to 299x299 pixels.
            // - The colors, represented as R, G, B in 1-byte each were converted to
            //   float using (value - Mean)/Scale.

            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output input = b.constant("input", TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes)));
            final Output output =
                b.div(
                    b.sub(
                        b.resizeBilinear(
                            b.expandDims(
                                b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),
                                b.constant("make_batch", batchTensor)
                            ),
                            b.constant("size", sizeTensor)
                        ),
                        b.constant("mean", meanTensor)
                    ),
                    b.constant("scale", scaleTensor)
                );
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    // Small helper that adds raw ops to the preprocessing graph.
    static class GraphBuilder {
        private final Graph g;
        final Scope scope;

        GraphBuilder(Graph g) {
            this.g = g;
            this.scope = new OpScope(g);
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope)
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .output(0);
        }

        Output<? extends TType> constant(String name, Tensor t) {
            return g.opBuilder("Const", name, scope)
                .setAttr("dtype", t.dataType())
                .setAttr("value", t)
                .build()
                .output(0);
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);
        }
    }

    // Releases the native TensorFlow session when the bean is destroyed.
    @PreDestroy
    public void close() {
        session.close();
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class LabelWithProbability {
        private String label;
        private float probability;
        private long elapsed;
    }
}
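The preprocessing graph in the service converts each color channel to a float with (value - Mean)/Scale. The following plain-Java sketch applies the same formula to raw channel bytes so the arithmetic is easy to check. The class `PixelNormalizer` and its `normalize` method are hypothetical names for illustration. In the service, this step runs as TensorFlow ops after `DecodeJpeg` and `ResizeBilinear`, not as a Java loop.

```java
final class PixelNormalizer {

    // Applies (value - mean) / scale to each channel byte, treating bytes as unsigned 0..255.
    static float[] normalize(byte[] channels, float mean, float scale) {
        float[] out = new float[channels.length];
        for (int i = 0; i < channels.length; i++) {
            int unsigned = channels[i] & 0xFF; // Java bytes are signed, so mask to 0..255
            out[i] = (unsigned - mean) / scale;
        }
        return out;
    }

    public static void main(String[] args) {
        // One RGB pixel, using mean = 0 and scale = 255 as in application.yaml.
        byte[] pixel = {0, (byte) 128, (byte) 255};
        float[] normalized = normalize(pixel, 0f, 255f);
        for (float v : normalized) {
            System.out.println(v); // values fall in the range [0, 1]
        }
    }
}
```

With mean 0 and scale 255, every channel value is mapped into the range [0, 1].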
application.yaml
tf:
  frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pb
  labelsPath: inception-v3/imagenet_slim_labels.txt
  outputLayer: InceptionV3/Predictions/Reshape_1
  image:
    width: 299
    height: 299
    mean: 0
    scale: 255

logging.level.net.sf.jmimemagic: WARN

spring:
  servlet:
    multipart:
      max-file-size: 5MB
Application.java
package com.et.tf;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;

@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    // Loads the frozen model file and imports it into a TensorFlow graph.
    @Bean
    public Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {
        Resource graphResource = getResource(tfFrozenModelPath);

        Graph graph = new Graph();
        graph.importGraphDef(GraphDef.parseFrom(graphResource.getInputStream()));
        log.info("Loaded Tensorflow model");
        return graph;
    }

    // Looks for the path on the file system first, then on the classpath.
    private Resource getResource(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) {
        Resource graphResource = new FileSystemResource(tfFrozenModelPath);
        if (!graphResource.exists()) {
            graphResource = new ClassPathResource(tfFrozenModelPath);
        }
        if (!graphResource.exists()) {
            throw new IllegalArgumentException(String.format("File %s does not exist", tfFrozenModelPath));
        }
        return graphResource;
    }

    // Reads one label per line; if a line contains ':', only the text after the first ':' is kept.
    @Bean
    public List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {
        Resource labelsRes = getResource(labelsPath);
        log.info("Loaded model labels");
        return IOUtils.readLines(labelsRes.getInputStream(), StandardCharsets.UTF_8).stream()
            .map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0))
            .collect(Collectors.toList());
    }
}
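The label-loading bean strips an optional prefix from each line: if a line contains a colon, only the text after the first colon is kept. The sketch below isolates that rule so it can be tried without Spring. The class `LabelParser` and the sample lines are hypothetical; the real labels file may not use such prefixes at all.

```java
import java.util.List;
import java.util.stream.Collectors;

final class LabelParser {

    // Same mapping as in tfModelLabels: drop everything up to and including the first ':'.
    static List<String> parse(List<String> lines) {
        return lines.stream()
            .map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Made-up sample lines, with and without an "index:" prefix.
        List<String> lines = List.of("0:background", "tench", "2:gold:fish");
        System.out.println(parse(lines)); // [background, tench, gold:fish]
    }
}
```

Only the first colon is treated as a separator, so a label that itself contains a colon keeps the rest of its text.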
The above is only the key code; the complete code is in the repository below.
Code repository
https://github.com/Harries/springboot-demo
4. Testing
Start the Spring Boot application
Test image classification
Open http://127.0.0.1:8080/, upload an image, and click the classify button
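As an alternative to the upload page, the /api/classify endpoint can be called directly. Here is a minimal sketch using the JDK 11 `java.net.http` client. It assumes the application is running on 127.0.0.1:8080 and takes the image path as the first command-line argument. The class `ClassifyClient`, its helper methods, and the boundary string are hypothetical names for illustration. The form field is named `file` to match the controller's `@RequestParam MultipartFile file`.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

final class ClassifyClient {

    // Builds a multipart/form-data body with a single part named "file".
    static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        byte[] header = ("--" + boundary + "\r\n"
            + "Content-Disposition: form-data; name=\"file\"; filename=\"" + fileName + "\"\r\n"
            + "Content-Type: application/octet-stream\r\n\r\n").getBytes(StandardCharsets.UTF_8);
        byte[] footer = ("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(header, 0, header.length);
        out.write(content, 0, content.length);
        out.write(footer, 0, footer.length);
        return out.toByteArray();
    }

    // Creates the POST request for the /api/classify endpoint.
    static HttpRequest classifyRequest(String baseUrl, String boundary, byte[] body) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/api/classify"))
            .header("Content-Type", "multipart/form-data; boundary=" + boundary)
            .POST(HttpRequest.BodyPublishers.ofByteArray(body))
            .build();
    }

    public static void main(String[] args) throws Exception {
        byte[] image = Files.readAllBytes(Path.of(args[0])); // path to a JPEG file
        String boundary = "----classify-demo-boundary";
        HttpRequest request = classifyRequest(
            "http://127.0.0.1:8080", boundary, multipartBody(boundary, "image.jpg", image));
        HttpResponse<String> response =
            HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        // On success the body is the JSON form of LabelWithProbability (label, probability, elapsed).
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```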
5. Summary
This article walked through integrating TensorFlow into Spring Boot to implement image detection.