2025-03-18 08:46:50 +08:00

230 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
@Remark: 自定义视图集
"""
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.viewsets import ModelViewSet
from utils.filters import DataLevelPermissionsFilter
from utils.jsonResponse import SuccessResponse,ErrorResponse,DetailResponse
from utils.permission import CustomPermission
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from django.core.exceptions import ValidationError
from utils.exception import APIException
from django_filters.rest_framework import DjangoFilterBackend
from django_filters import utils
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAuthenticated
from utils.export_excel2 import LyExportExcel
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
    """
    Wrap Django's ``get_object_or_404`` shortcut and convert bad lookup values into an API error.

    When no object matches, the wrapped Django shortcut raises ``Http404`` as
    usual. When the lookup itself fails because ``filter_kwargs`` do not match
    the field types (``TypeError``, ``ValueError`` or ``ValidationError``, e.g.
    a non-numeric primary key), the error is re-raised as the project's
    ``APIException`` with a user-facing message instead of surfacing as a 500.

    :param queryset: model, manager or queryset passed through to Django's shortcut.
    :param filter_args: positional lookup arguments (e.g. ``Q`` objects).
    :param filter_kwargs: keyword lookup arguments.
    :return: the single matching object.
    :raises APIException: when the lookup values have the wrong type.
    """
    try:
        return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
    except (TypeError, ValueError, ValidationError) as exc:
        # Chain the original error so the real cause stays visible in tracebacks/logs.
        raise APIException(message='该对象不存在或者无访问权限') from exc
class CustomDjangoFilterBackend(DjangoFilterBackend):
    """
    DjangoFilterBackend that still honours the legacy ``filter_fields`` / ``filter_class`` view attributes.

    django-filter 22.1 stopped supporting the 21.1-and-earlier attribute names
    ``filter_fields`` and ``filter_class`` in favour of ``filterset_fields`` and
    ``filterset_class``. This backend maps the old names onto the new ones
    (emitting a deprecation notice) so older views keep filtering.
    """

    def get_filterset_class(self, view, queryset=None):
        """
        Return the ``FilterSet`` class used to filter the queryset, or None.

        Resolution order: ``filterset_class`` (falling back to the legacy
        ``filter_class``), then an auto-generated FilterSet built from
        ``filterset_fields`` (falling back to the legacy ``filter_fields``).

        :param view: the view being filtered.
        :param queryset: the queryset about to be filtered; may be None.
        :return: a FilterSet subclass, or None when the view declares no filtering.
        """
        filterset_class = getattr(view, 'filterset_class', None)
        filterset_fields = getattr(view, 'filterset_fields', None)

        # Legacy alias: filter_class -> filterset_class.
        if filterset_class is None and hasattr(view, 'filter_class'):
            utils.deprecate(
                "`%s.filter_class` attribute should be renamed `filterset_class`."
                % view.__class__.__name__)
            filterset_class = getattr(view, 'filter_class', None)

        # Legacy alias: filter_fields -> filterset_fields.
        # Test for "empty" rather than "is None": base viewsets such as
        # CustomModelViewSet default filterset_fields to an empty tuple, which
        # would otherwise hide a subclass's legacy filter_fields entirely.
        if not filterset_fields and getattr(view, 'filter_fields', None):
            utils.deprecate(
                "`%s.filter_fields` attribute should be renamed `filterset_fields`."
                % view.__class__.__name__)
            filterset_fields = getattr(view, 'filter_fields', None)

        if filterset_class:
            filterset_model = filterset_class._meta.model
            # FilterSets do not need to specify a Meta class (model may be None).
            if filterset_model and queryset is not None:
                assert issubclass(queryset.model, filterset_model), \
                    'FilterSet model %s does not match queryset model %s' % \
                    (filterset_model, queryset.model)
            return filterset_class

        if filterset_fields and queryset is not None:
            # Build a FilterSet on the fly for the queryset's model.
            MetaBase = getattr(self.filterset_base, 'Meta', object)

            class AutoFilterSet(self.filterset_base):
                class Meta(MetaBase):
                    model = queryset.model
                    fields = filterset_fields

            return AutoFilterSet

        return None
class CustomModelViewSet(ModelViewSet):
    """
    Project-wide ModelViewSet with a uniform response format and per-action serializers.

    ``get_serializer_class`` looks up ``<action>_serializer_class`` first, so
    create/update (or any other action) can use its own serializer.

    Class attributes subclasses may override:
    (1) create_serializer_class: serializer used when creating.
    (2) update_serializer_class: serializer used when updating.
    (3) export_serializer_class: serializer used by ``export``; may be None,
        in which case ``serializer_class`` is used.
    (4) export_field_dict: columns to export, e.g.
        ``export_field_dict = {'name': '姓名', 'age': '年龄'}``.
    (5) export_download_mode: how the Excel file is delivered. ``"temp"``
        streams an in-memory file directly (nothing is saved on the server);
        ``"url"`` saves the file and returns an http/https download link.
    (6) export_download_filename: optional Excel file name without extension,
        e.g. ``export_download_filename = "导出用户数据"``.
    """
    values_queryset = None
    ordering_fields = '__all__'
    create_serializer_class = None
    update_serializer_class = None
    export_serializer_class = None
    export_field_dict = {}
    export_download_mode = "temp"
    export_download_filename = ""
    filterset_fields = ()
    # filterset_fields = '__all__'
    search_fields = ()
    # Applied before filter_backends on every filter_queryset() call.
    extra_filter_backends = [DataLevelPermissionsFilter]
    permission_classes = [CustomPermission, IsAuthenticated]
    # Views that want a little more speed can override filter_backends and drop
    # the filter classes they do not use.
    filter_backends = [CustomDjangoFilterBackend, SearchFilter, OrderingFilter]

    @staticmethod
    def _normalize_keys(keys):
        """Return ``keys`` unchanged unless it is a string, which is split on commas with blank parts dropped."""
        # A bare string passed to ``__in`` would be iterated character by
        # character ("12" -> ids 1 and 2), so split it like the URL form.
        if isinstance(keys, str):
            return [key for key in keys.split(',') if key]
        return keys

    def filter_queryset(self, queryset):
        """Run the queryset through extra_filter_backends, then filter_backends."""
        for backend in (list(self.extra_filter_backends) + list(self.filter_backends)):
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    def get_queryset(self):
        """
        Return ``values_queryset`` when a subclass set one, otherwise defer to DRF.

        The check is ``is not None`` rather than truthiness: ``bool(queryset)``
        executes the query, and the evaluated result cache of a class-level
        queryset would then be reused (stale) by every later request. ``.all()``
        hands each request a fresh clone, as DRF's own ``get_queryset`` does.
        """
        from django.db.models.query import QuerySet

        values_queryset = getattr(self, 'values_queryset', None)
        if values_queryset is None:
            return super().get_queryset()
        if isinstance(values_queryset, QuerySet):
            return values_queryset.all()
        return values_queryset

    def get_serializer_class(self):
        """Prefer ``<action>_serializer_class`` when set, else DRF's default lookup."""
        action_serializer_class = getattr(self, f"{self.action}_serializer_class", None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    # Export to Excel
    @action(methods=['post'], detail=False)
    def export(self, request, *args, **kwargs):
        """
        Export the filtered queryset to Excel via LyExportExcel.

        If the request body carries ``ids``, only those records are exported.
        In ``"temp"`` mode the LyExportExcel result is returned as-is; otherwise
        it is wrapped in a DetailResponse.
        """
        queryset = self.filter_queryset(self.get_queryset())
        ids = self._normalize_keys(request.data.get('ids', None))
        if ids:
            queryset = queryset.filter(id__in=ids)
        export_serializer_class = self.export_serializer_class or self.serializer_class
        data = export_serializer_class(queryset, many=True, request=request).data
        if self.export_download_filename:
            exporter = LyExportExcel(request=request, downloadMode=self.export_download_mode,
                                     fileName=self.export_download_filename + ".xlsx")
        else:
            exporter = LyExportExcel(request=request, downloadMode=self.export_download_mode)
        result = exporter.export_data(self.export_field_dict, data)
        if self.export_download_mode == "temp":
            return result
        return DetailResponse(data=result, msg="导出成功")

    def create(self, request, *args, **kwargs):
        """Validate and save a new object, returning it in a DetailResponse."""
        serializer = self.get_serializer(data=request.data, request=request)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return DetailResponse(data=serializer.data, msg="新增成功")

    def list(self, request, *args, **kwargs):
        """List the filtered queryset, paginated when a paginator is configured."""
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True, request=request)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True, request=request)
        return SuccessResponse(data=serializer.data, msg="获取成功")

    def retrieve(self, request, *args, **kwargs):
        """Return a single object in a SuccessResponse."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return SuccessResponse(data=serializer.data, msg="获取成功")

    def update(self, request, *args, **kwargs):
        """Validate and save changes to an object (partial when kwargs['partial'] is true)."""
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, request=request, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return DetailResponse(data=serializer.data, msg="更新成功")

    def get_object_list(self):
        """
        Return the filtered queryset of objects addressed by a comma-separated lookup value.

        Enables batch deletion over HTTP DELETE, e.g. ``/api/admin/user/1,2,3/``
        addresses the users with ids 1, 2 and 3. The queryset (not a single
        instance) is passed to ``check_object_permissions``.
        """
        queryset = self.filter_queryset(self.get_queryset())
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        assert lookup_url_kwarg in self.kwargs, (
            'Expected view %s to be called with a URL keyword argument '
            'named "%s". Fix your URL conf, or set the `.lookup_field` '
            'attribute on the view correctly.' %
            (self.__class__.__name__, lookup_url_kwarg)
        )
        # Blank segments ("1,2," or "1,,2") are dropped: '' is not a valid
        # integer key and Django would raise ValueError for it.
        lookup_values = self._normalize_keys(self.kwargs[lookup_url_kwarg])
        filter_kwargs = {f"{self.lookup_field}__in": lookup_values}
        obj = queryset.filter(**filter_kwargs)
        self.check_object_permissions(self.request, obj)
        return obj

    def destroy(self, request, *args, **kwargs):
        """Delete every object addressed by the URL, e.g. ``/api/admin/user/1,2,3/``."""
        instance = self.get_object_list()
        self.perform_destroy(instance)
        return DetailResponse(data=[], msg="删除成功")

    def perform_destroy(self, instance):
        """Delete ``instance``; from ``destroy`` this is a queryset, so it is a bulk delete."""
        instance.delete()

    # Swagger schema for the request body of multiple_delete.
    keys = openapi.Schema(description='主键列表', type=openapi.TYPE_ARRAY,
                          items=openapi.Schema(type=openapi.TYPE_STRING))

    @swagger_auto_schema(request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=['keys'],
        properties={'keys': keys}
    ), operation_summary='批量删除')
    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """
        Delete the objects whose ids are listed in the request body's ``keys``.

        ``keys`` may be a list or a comma-separated string. Deletion goes through
        ``filter_queryset`` so data-level permission filtering applies here just
        as it does for ``destroy``.
        """
        keys = self._normalize_keys(request.data.get('keys', None))
        if not keys:
            return ErrorResponse(msg="未获取到keys字段")
        self.filter_queryset(self.get_queryset()).filter(id__in=keys).delete()
        return SuccessResponse(data=[], msg="删除成功")