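Integrating TensorFlow into Spring Boot for Image Detection
1. What is TensorFlow?
TensorFlow takes its name from tensors flowing through a computational graph. It is built on graph-based automatic differentiation. Besides computing gradients for you automatically, it provides a wide range of common operations (ops, i.e. the nodes of the graph), common loss functions, and optimization algorithms.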

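TensorFlow is an open-source software library for high-performance numerical computation. Its flexible architecture lets you deploy computation to a variety of platforms (CPU, GPU, TPU) and devices (desktops, server clusters, mobile devices, edge devices, and so on).
It is an open-source machine learning library for both research and production. It offers APIs that let beginners and experts develop for desktop, mobile, web, and cloud environments.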
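TensorFlow computes with data flow graphs. You first build a data flow graph, then put your data (held as tensors) into the graph to be computed. Nodes in the graph represent mathematical operations, and edges represent the multidimensional data arrays (tensors) passed between nodes. During training, tensors keep flowing from one node of the graph to another, which is where the name TensorFlow comes from.
Tensors come in different ranks:
- A rank-0 tensor is a scalar, i.e. a single number such as 1.
- A rank-1 tensor is a vector, such as the one-dimensional [1, 2, 3].
- A rank-2 tensor is a matrix, such as the two-dimensional [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
- The pattern continues with rank-3 (three-dimensional) tensors and higher.
The name describes computation as tensors flowing from one end of the graph to the other. It is a vivid picture of how complex data structures are transmitted, analyzed, and processed in an artificial neural network.
In machine learning, numerical data usually takes one of four forms:
(1) Scalar: a single number, the smallest unit of computation, such as 1 or 3.2.
(2) Vector: a one-dimensional array of scalars, such as [1, 3.2, 4.6].
(3) Matrix: a two-dimensional array of scalars.
(4) Tensor: a collection of data organized as a (usually) multidimensional array. It can be thought of as a higher-dimensional matrix.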
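To make the notion of rank concrete, here is a small plain-Java sketch (no TensorFlow involved). It treats nested Java arrays as tensors and reports their shape. The class `TensorRankDemo` and its `shapeOf` helper are hypothetical and exist only for illustration. A rank is simply the number of entries in the shape.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TensorRankDemo {
    // Returns the shape of a rectangular nested Java array, e.g. [3, 3] for a 3x3 matrix.
    // A non-array value (a scalar) has an empty shape, i.e. rank 0.
    static int[] shapeOf(Object value) {
        List<Integer> dims = new ArrayList<>();
        Object current = value;
        while (current != null && current.getClass().isArray()) {
            int len = Array.getLength(current);
            dims.add(len);
            if (len == 0) {
                break;
            }
            current = Array.get(current, 0); // descend into the first element
        }
        return dims.stream().mapToInt(Integer::intValue).toArray();
    }

    public static void main(String[] args) {
        float scalar = 1f;                                       // rank 0
        float[] vector = {1f, 2f, 3f};                           // rank 1
        float[][] matrix = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};    // rank 2
        float[][][] cube = new float[2][3][4];                   // rank 3
        for (Object t : new Object[]{scalar, vector, matrix, cube}) {
            int[] shape = shapeOf(t);
            System.out.println("rank " + shape.length + ", shape " + Arrays.toString(shape));
        }
    }
}
```

Running `main` prints rank 0 through rank 3 with shapes `[]`, `[3]`, `[3, 3]` and `[2, 3, 4]`.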
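Basic TensorFlow concepts
- Graph: describes the computation; TensorFlow represents a computation as a graph
- Tensor: TensorFlow represents data as tensors; each tensor is a multidimensional array
- Operation: the nodes of a graph are ops; an op takes zero or more tensors as input, performs a computation, and produces zero or more tensors
- Session: a graph must be executed inside a session (this is graph mode, as in TensorFlow 1.x and the Java API used below; TensorFlow 2.x in Python executes eagerly by default)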
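The four concepts above can be illustrated with a toy dataflow graph in plain Java. This is a deliberately simplified sketch and not TensorFlow's API. The classes `ToyGraphDemo` and `Node` are hypothetical, each node produces a single `double` instead of a tensor, and the `run` method plays the role of a session. Building the graph computes nothing; values only flow along the edges when a node is fetched, with placeholders supplied through `feeds`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class ToyGraphDemo {
    // A node (op) takes the values of zero or more input nodes and produces one value.
    static final class Node {
        final String[] inputs;
        final Function<double[], Double> op;

        Node(Function<double[], Double> op, String... inputs) {
            this.op = op;
            this.inputs = inputs;
        }
    }

    // Building the graph only records nodes; nothing is computed yet.
    final Map<String, Node> nodes = new HashMap<>();

    ToyGraphDemo add(String name, Function<double[], Double> op, String... inputs) {
        nodes.put(name, new Node(op, inputs));
        return this;
    }

    // Toy "session": evaluates the fetched node, taking placeholder values from feeds.
    double run(String fetch, Map<String, Double> feeds) {
        if (feeds.containsKey(fetch)) {
            return feeds.get(fetch);
        }
        Node node = nodes.get(fetch);
        if (node == null) {
            throw new IllegalArgumentException("No node or feed named " + fetch);
        }
        double[] args = new double[node.inputs.length];
        for (int i = 0; i < args.length; i++) {
            args[i] = run(node.inputs[i], feeds); // values flow along the edges
        }
        return node.op.apply(args);
    }

    public static void main(String[] args) {
        // y = w * x + b, where x is a placeholder fed at run time
        ToyGraphDemo g = new ToyGraphDemo()
                .add("w", a -> 2.0)
                .add("b", a -> 1.0)
                .add("mul", a -> a[0] * a[1], "w", "x")
                .add("y", a -> a[0] + a[1], "mul", "b");
        System.out.println(g.run("y", Map.of("x", 3.0))); // prints 7.0
    }
}
```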
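Typical workflow for writing TensorFlow (1.x-style) code
- Define variables and placeholders
- Write the model equations from the underlying math
- Define the loss function (cost)
- Define the gradient-descent optimizer (GradientDescentOptimizer)
- Train in a session, using a for loop
- Save the model with a Saver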
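The steps above can be traced in plain Java without TensorFlow. The sketch below fits `y = w * x + b` by batch gradient descent on a mean-squared-error cost, with hand-written gradients standing in for TensorFlow's automatic differentiation. The class name, learning rate, and sample data are illustrative assumptions. Saving with a `Saver` has no counterpart here; returning `w` and `b` stands in for it.

```java
public class LinearRegressionGD {
    // Fits y = w * x + b by batch gradient descent on the mean squared error (the cost).
    static double[] fit(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0;                       // trainable variables
        double b = 0.0;
        int n = xs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {           // training loop
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = (w * xs[i] + b) - ys[i];          // model equation minus label
                gradW += 2.0 * error * xs[i] / n;                // d(MSE)/dw
                gradB += 2.0 * error / n;                        // d(MSE)/db
            }
            w -= learningRate * gradW;                           // gradient-descent update
            b -= learningRate * gradB;
        }
        return new double[]{w, b};
    }

    public static void main(String[] args) {
        double[] xs = {0, 1, 2, 3, 4};
        double[] ys = {1, 3, 5, 7, 9};        // generated from y = 2x + 1
        double[] wb = fit(xs, ys, 0.05, 2000);
        System.out.printf("w=%.3f b=%.3f%n", wb[0], wb[1]); // prints w=2.000 b=1.000
    }
}
```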
2. Environment preparation
Integration steps
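- Model building: first, define and train a deep learning model in TensorFlow. This may involve choosing a suitable network architecture, optimizer, and loss function.
- Training data preparation: next, prepare the data used to train and validate the model. This may include data cleaning, labeling, and preprocessing.
- REST API design: to interact with the TensorFlow model, create a REST API in Spring Boot. This can be done with Spring Boot's built-in support, for example Spring MVC or Spring WebFlux.
- Model deployment: once the model is trained, deploy it into the Spring Boot application. Typically the trained model is exported (for example, from Python) as a SavedModel or a frozen GraphDef (.pb) file. The Spring Boot application then loads and runs it through the TensorFlow Java API. This example loads a frozen Inception v3 GraphDef.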
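There are a few key points to watch during integration. First, firewall settings can interfere with network communication during TensorFlow training, for example in distributed training or when downloading datasets and pre-trained models. Make sure your firewall allows TensorFlow to reach the network resources it needs, otherwise training may be interrupted. Second, pay attention to version compatibility. Spring Boot and TensorFlow (including the TensorFlow Java artifacts) each have their own release cycles. This project, for instance, pairs tensorflow-core-platform 0.5.0 with Java 11, and using compatible versions avoids a lot of unnecessary trouble.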
3. Code project
Experiment goal
Implement image detection: classify an uploaded image with a pre-trained Inception v3 model.
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>springboot-demo</artifactId>
<groupId>com.et</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>Tensorflow</artifactId>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-autoconfigure</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>0.5.0</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>jmimemagic</groupId>
<artifactId>jmimemagic</artifactId>
<version>0.1.2</version>
</dependency>
<dependency>
<groupId>jakarta.platform</groupId>
<artifactId>jakarta.jakartaee-api</artifactId>
<version>9.0.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.16.1</version>
</dependency>
<dependency>
<groupId>org.springframework.restdocs</groupId>
<artifactId>spring-restdocs-mockmvc</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
controller
package com.et.tf.api;
import java.io.IOException;
import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
@RestController
@RequestMapping("/api")
public class AppController {
@Autowired
ClassifyImageService classifyImageService;
@PostMapping(value = "/classify")
@CrossOrigin(origins = "*")
public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {
checkImageContents(file);
return classifyImageService.classifyImage(file.getBytes());
}
@RequestMapping(value = "/")
public String index() {
return "index";
}
private void checkImageContents(MultipartFile file) {
MagicMatch match;
try {
match = Magic.getMagicMatch(file.getBytes());
} catch (Exception e) {
throw new RuntimeException(e);
}
String mimeType = match.getMimeType();
if (!mimeType.startsWith("image")) {
throw new IllegalArgumentException("Not an image type: " + mimeType);
}
}
}
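The controller relies on jmimemagic to confirm that the upload is an image. If you would rather avoid that dependency, you can sniff the leading "magic" bytes of a few common formats yourself. The sketch below does this in plain Java. The class `ImageSniffer` and its methods are hypothetical, and only JPEG, PNG, and GIF signatures are recognized, so it is narrower than jmimemagic's detection.

```java
public class ImageSniffer {
    // Returns a MIME type for a few common image formats based on their leading bytes,
    // or null if the data does not start with any of the known signatures.
    static String sniffImageMimeType(byte[] data) {
        if (startsWith(data, 0xFF, 0xD8, 0xFF)) {
            return "image/jpeg";
        }
        if (startsWith(data, 0x89, 'P', 'N', 'G', 0x0D, 0x0A, 0x1A, 0x0A)) {
            return "image/png";
        }
        if (startsWith(data, 'G', 'I', 'F', '8')) {
            return "image/gif";
        }
        return null;
    }

    // Compares the first bytes of data (as unsigned values) with the expected prefix.
    private static boolean startsWith(byte[] data, int... prefix) {
        if (data == null || data.length < prefix.length) {
            return false;
        }
        for (int i = 0; i < prefix.length; i++) {
            if ((data[i] & 0xFF) != prefix[i]) {
                return false;
            }
        }
        return true;
    }

    // Mirrors the controller's check: reject anything that is not a recognized image.
    static void checkImageContents(byte[] data) {
        if (sniffImageMimeType(data) == null) {
            throw new IllegalArgumentException("Not a supported image type");
        }
    }

    public static void main(String[] args) {
        byte[] jpegHeader = {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF, (byte) 0xE0};
        System.out.println(sniffImageMimeType(jpegHeader));          // image/jpeg
        System.out.println(sniffImageMimeType("hello".getBytes()));  // null
    }
}
```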
service
package com.et.tf.service;
import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;
// Inspired by https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {
private final Session session;
private final List<String> labels;
private final String outputLayer;
private final int W;
private final int H;
private final float mean;
private final float scale;
public ClassifyImageService(
Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,
@Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,
@Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale
) {
this.labels = labels;
this.outputLayer = outputLayer;
this.H = imageH;
this.W = imageW;
this.mean = mean;
this.scale = scale;
this.session = new Session(inceptionGraph);
}
public LabelWithProbability classifyImage(byte[] imageBytes) {
long start = System.currentTimeMillis();
try (Tensor image = normalizedImageToTensor(imageBytes)) {
float[] labelProbabilities = classifyImageProbabilities(image);
int bestLabelIdx = maxIndex(labelProbabilities);
LabelWithProbability labelWithProbability =
new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);
log.debug(String.format(
"Image classification [%s %.2f%%] took %d ms",
labelWithProbability.getLabel(),
labelWithProbability.getProbability(),
labelWithProbability.getElapsed()
)
);
return labelWithProbability;
}
}
private float[] classifyImageProbabilities(Tensor image) {
try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {
final Shape resultShape = result.shape();
final long[] rShape = resultShape.asArray();
if (resultShape.numDimensions() != 2 || rShape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rShape)
));
}
int nlabels = (int) rShape[1];
FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();
float[] dst = new float[nlabels];
resultFloatBuffer.read(dst);
return dst;
}
}
private int maxIndex(float[] probabilities) {
int best = 0;
for (int i = 1; i < probabilities.length; ++i) {
if (probabilities[i] > probabilities[best]) {
best = i;
}
}
return best;
}
private Tensor normalizedImageToTensor(byte[] imageBytes) {
try (Graph g = new Graph();
TInt32 batchTensor = TInt32.scalarOf(0);
TInt32 sizeTensor = TInt32.vectorOf(H, W);
TFloat32 meanTensor = TFloat32.scalarOf(mean);
TFloat32 scaleTensor = TFloat32.scalarOf(scale);
// The encoded image bytes also live in a native tensor, so close it together with the others
TString imageTensor = TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes))
) {
GraphBuilder b = new GraphBuilder(g);
// Python tutorial: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image
// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
//
// - The model was trained with images scaled to 299x299 pixels.
// - The colors, represented as R, G, B in 1 byte each, were converted to
// float using (value - Mean)/Scale.
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
final Output input = b.constant("input", imageTensor);
final Output output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),
b.constant("make_batch", batchTensor)
),
b.constant("size", sizeTensor)
),
b.constant("mean", meanTensor)
),
b.constant("scale", scaleTensor)
);
try (Session s = new Session(g)) {
return s.runner().fetch(output.op().name()).run().get(0);
}
}
}
static class GraphBuilder {
final Scope scope;
GraphBuilder(Graph g) {
this.g = g;
this.scope = new OpScope(g);
}
Output div(Output x, Output y) {
return binaryOp("Div", x, y);
}
Output sub(Output x, Output y) {
return binaryOp("Sub", x, y);
}
Output resizeBilinear(Output images, Output size) {
return binaryOp("ResizeBilinear", images, size);
}
Output expandDims(Output input, Output dim) {
return binaryOp("ExpandDims", input, dim);
}
Output cast(Output value, DataType dtype) {
return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);
}
Output decodeJpeg(Output contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope)
.addInput(contents)
.setAttr("channels", channels)
.build()
.output(0);
}
Output<? extends TType> constant(String name, Tensor t) {
return g.opBuilder("Const", name, scope)
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build()
.output(0);
}
private Output binaryOp(String type, Output in1, Output in2) {
return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);
}
private final Graph g;
}
@PreDestroy
public void close() {
session.close();
}
@Data
@NoArgsConstructor
@AllArgsConstructor
public static class LabelWithProbability {
private String label;
private float probability;
private long elapsed;
}
}
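Two small pieces of the service's arithmetic can be checked in isolation. The preprocessing graph maps every channel value through `(value - mean) / scale`, and with the configured `mean: 0` and `scale: 255` this puts 8-bit values into [0, 1]. `maxIndex` then picks the single best label. The plain-Java sketch below reproduces the normalization formula and generalizes `maxIndex` to a top-k ranking. The class `PostProcessingDemo`, the `topK` helper, and the sample labels and probabilities are hypothetical and not part of the project.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.IntStream;

public class PostProcessingDemo {
    // Same formula the preprocessing graph applies per channel value: (value - mean) / scale.
    static float normalize(int pixel, float mean, float scale) {
        return (pixel - mean) / scale;
    }

    // Indices of the k highest probabilities, best first (a top-k variant of maxIndex).
    static int[] topK(float[] probabilities, int k) {
        return IntStream.range(0, probabilities.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed())
                .limit(k)
                .mapToInt(Integer::intValue)
                .toArray();
    }

    public static void main(String[] args) {
        System.out.println(normalize(255, 0f, 255f)); // 1.0
        System.out.println(normalize(0, 0f, 255f));   // 0.0

        List<String> labels = List.of("background", "cat", "dog", "car"); // made-up labels
        float[] probs = {0.05f, 0.60f, 0.30f, 0.05f};                      // made-up output
        for (int idx : topK(probs, 2)) {
            System.out.printf("%s %.2f%%%n", labels.get(idx), probs[idx] * 100f);
        }
    }
}
```

Returning the top few labels instead of only the best one can help when the model's confidence is spread across similar ImageNet classes.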
application.yaml
tf:
  frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pb
  labelsPath: inception-v3/imagenet_slim_labels.txt
  outputLayer: InceptionV3/Predictions/Reshape_1
  image:
    width: 299
    height: 299
    mean: 0
    scale: 255
logging.level.net.sf.jmimemagic: WARN
spring:
  servlet:
    multipart:
      max-file-size: 5MB
Application.java
package com.et.tf;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;
@SpringBootApplication
@Slf4j
public class Application {
public static void main(String[] args) {
SpringApplication.run(Application.class, args);
}
@Bean
public Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {
Resource graphResource = getResource(tfFrozenModelPath);
Graph graph = new Graph();
// Close the model file stream once the GraphDef has been parsed
try (var in = graphResource.getInputStream()) {
graph.importGraphDef(GraphDef.parseFrom(in));
}
log.info("Loaded Tensorflow model");
return graph;
}
// Shared by both beans: prefer a file on disk, fall back to the classpath (e.g. src/main/resources)
private Resource getResource(String path) {
Resource resource = new FileSystemResource(path);
if (!resource.exists()) {
resource = new ClassPathResource(path);
}
if (!resource.exists()) {
throw new IllegalArgumentException(String.format("File %s does not exist", path));
}
return resource;
}
@Bean
public List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {
Resource labelsRes = getResource(labelsPath);
try (var in = labelsRes.getInputStream()) {
// Strip an optional "index:" prefix from each line
List<String> labels = IOUtils.readLines(in, StandardCharsets.UTF_8).stream()
.map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0)).collect(Collectors.toList());
log.info("Loaded {} model labels", labels.size());
return labels;
}
}
}
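The label-loading bean keeps everything after the first colon on each line and leaves lines without a colon untouched. The plain-Java sketch below isolates that rule so it can be checked on its own. The class `LabelParsingDemo` and the sample lines are hypothetical; the `"1:tench"` style merely illustrates the kind of prefix the code is prepared to strip.

```java
import java.util.List;
import java.util.stream.Collectors;

public class LabelParsingDemo {
    // Same rule as the tfModelLabels bean: drop everything up to and including the first ':'.
    static String stripIndexPrefix(String line) {
        int colon = line.indexOf(':');
        return colon >= 0 ? line.substring(colon + 1) : line;
    }

    public static void main(String[] args) {
        List<String> lines = List.of("background", "1:tench", "2:goldfish");
        List<String> labels = lines.stream()
                .map(LabelParsingDemo::stripIndexPrefix)
                .collect(Collectors.toList());
        System.out.println(labels); // [background, tench, goldfish]
    }
}
```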
The listings above are only the key parts of the code; see the repository below for the complete project.
Code repository
https://github.com/Harries/springboot-demo
4. Testing
Start the Spring Boot application.
Test image classification
Open http://127.0.0.1:8080/, upload an image, and click classify.
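The endpoint can also be called without the web page. The Java 11 `java.net.http` client below sends a `multipart/form-data` request with a single part named `file`, matching the controller's `@RequestParam MultipartFile file`. The class `ClassifyClient`, its `multipartBody` helper, and the default file name `cat.jpg` are hypothetical, and the sketch assumes the application is running locally on port 8080. The response body should be the JSON form of `LabelWithProbability`.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

public class ClassifyClient {
    // Builds a multipart/form-data body containing one file part named "file".
    static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String header = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"file\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: application/octet-stream\r\n\r\n";
        out.writeBytes(header.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        Path image = Path.of(args.length > 0 ? args[0] : "cat.jpg"); // hypothetical sample image
        String boundary = "----" + UUID.randomUUID();
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://127.0.0.1:8080/api/classify"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(
                        multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image))))
                .build();
        HttpResponse<String> response =
                HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```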

5. Summary
That completes the walkthrough of integrating TensorFlow into Spring Boot to implement image detection.